CuTe DSL - Notes

Block 1 — Layouts and Tensors in CUTE

What a layout is

A CUTE layout is a pair: a shape and a stride. Together they define a pure function from a coordinate tuple to an integer offset. A layout by itself describes how to walk an indexed grid — it does not point to memory, hold data, or know about pointers.

layout = make_layout(shape=(4, 3), stride=(1, 4))
# coord (i, j) maps to offset i*1 + j*4
# coord (2, 1) -> offset 6

Layouts are almost always JIT-time (compile-time) objects. The CUTE DSL knows their shape and stride at the moment it generates code, which is what lets the compiler unroll loops, fold offsets into constants, and statically validate partitioning. This is also why you see cutlass.const_expr(...) and range_constexpr everywhere — they mark values that must be resolvable at JIT-time.

JIT vs compile vs runtime

These three terms are easy to conflate, so worth pinning down:

  • Runtime — the kernel is executing on the GPU, on actual data.
  • Compile time — a compiler (e.g., nvcc, LLVM) is producing machine code, before the program runs at all.
  • JIT (Just-In-Time) — a compile step that happens during program execution but before the kernel runs on the GPU.

CUTE DSL is JIT: when cute.compile(fmha, ...) is called, the Python process introspects the @cute.jit functions, builds an IR, and lowers it to PTX/SASS. After that, the compiled kernel can be launched and runs on the GPU at runtime. So in CUTE, “compile-time” and “JIT-time” mean the same thing in practice.

Hierarchical shapes

Shapes can be nested, and this is the move that makes the rest of CUTE possible:

shape  = ((2, 2), 3)
stride = ((1, 4), 8)
# coord ((i0, i1), j) -> offset i0*1 + i1*4 + j*8

The outer mode has rank 2 (the nested pair (2, 2) and the scalar 3). Each inner sub-shape is itself a coordinate component.

A flat mode has one stride. A nested (hierarchical) mode is a tuple of sub-shapes, each with its own stride; the offset for that mode is the sum of all sub-stride contributions. Concretely, suppose “the M axis” of size 64 is described as \(M = (8, 2, 4)\) with strides \((1, 16, 32)\). A logical M coordinate \(m \in [0, 64)\) is decomposed into \((m_0, m_1, m_2)\) with \(m = m_0 + 8 m_1 + 16 m_2\), and the memory offset is:

\[\text{offset}(m) = m_0 \cdot 1 + m_1 \cdot 16 + m_2 \cdot 32\]

So one logical M coordinate pulls contributions from three different strides. This is exactly how WGMMA’s M axis is described — the hierarchy is the hardware geometry, and the question “is the M dimension a single contiguous stride?” no longer has a yes/no answer; it has a hierarchical answer.

Operations on layouts used in K4
  • cute.size(layout) — total element count.
  • cute.size(layout, mode=[k]) — size of the k-th axis (handles hierarchy).
  • cute.rank(layout) — number of top-level modes.
  • cute.slice_(layout, (None, None, 0)) — fix some coords, leave others free; returns a new layout.
  • cute.append(a, b) — extend a layout with another mode.
  • cute.flat_divide(tensor, tile_shape) — partition a tensor into tiles. After flat_divide(mQ, (Br, d)), the result has shape (Br, d, n_M_tiles, n_K_tiles, batch) — tile interior at the front, tile coordinates at the back.
What a tensor is

A tensor in CUTE is an (iterator, layout) pair. The iterator is a pointer-like handle that holds a base address (with type info); it is not a Python iterator in the for x in it sense — closer to a C++ pointer. Indexing a tensor at coord \(c\) means: compute \(\text{layout}(c)\), add it to the iterator, dereference.

\[\text{tensor}[c] \;=\; \ast(\text{iterator} + \text{layout}(c))\]
tensor = cute.make_tensor(some_pointer, my_layout)
val = tensor[(2, 1)]   # equivalent to *(some_pointer + my_layout(2, 1))
Aliasing: two views into the same storage

A direct consequence of tensor = (iterator, layout) is that two tensors can share an iterator but use different layouts — they are two different views of the same registers/SMEM/GMEM, with no data motion. K4 uses this trick repeatedly:

acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape)
acc_pv_mn = cute.make_tensor(
    acc_pv.iterator,
    self.layout_acc_mn(pv_tiled_mma, acc_pv.layout),
)

acc_pv is the WGMMA-native fragment; its layout is whatever hierarchical structure the SM90 atom expects. acc_pv_mn points to the same physical registers, but under a layout where mode-0 is “row in this thread’s view of the tile” and mode-1 is “column.” Writing acc_pv_mn[i, j] = x is just sugar over a register store, indexed in a way humans can reason about. The same alias-the-iterator pattern shows up for acc_qk_mn, tOcO_mn, and inside make_acc_into_op.

Register fragments and the tile

When make_fragment_C is called, CUTE allocates a per-thread register tensor. Two important consequences:

  • A WGMMA output tile is a single shared object — say \(64 \times 64 = 4096\) elements — produced by the entire warpgroup of 128 threads. Each thread holds \(4096 / 128 = 32\) of those elements in its registers. These 32 elements are not a contiguous sub-rectangle of the tile; they are scattered across it in a pattern fixed by the WGMMA hardware.
  • Therefore “thread T’s acc_pv_mn[0, 0]” and “thread T+1’s acc_pv_mn[0, 0]” refer to different positions in the 64×64 output tile, because they live in different threads’ registers.

The mn-view exists for human convenience: it lets you write acc_pv_mn[i, j] and pretend you are indexing rows and columns of a 2D fragment, but the indexing is into your thread’s local 32 registers. The mapping

\[(\text{thread\_id},\; i,\; j) \;\longrightarrow\; (M,\; N)\]

— that is, from a (thread, fragment_coord) pair to a position in the actual output tile — is encoded in the layout itself, specifically in tv_layout_C. This is the object the next block will unpack.

Why this matters for K4

The WGMMA C-layout on SM90 makes it so that as a single thread sweeps \(i\) from 0 to n_rows - 1 in its mn-view, it visits multiple different M rows of the output tile, not one row. Per-thread row-indexed softmax state — one s_max[i], one a_sum[i] per local row \(i\) — is therefore correct only if \(i\) truly corresponds to distinct output rows in this thread’s view, which the register-softmax design ensures by sizing those arrays to cute.size(acc_pv_mn, mode=[0]). The earlier SMEM softmax assumed row = tidx // 2 (one row per thread, scalar m_i), which is wrong against the actual WGMMA C-layout — and that is the layout-mismatched-scaling fingerprint described in the bug context.


Block 2 — TV Layouts: Thread/Value Partitioning of a Tile

This block builds on Block 1, where we established that a tensor is (iterator, layout) and that a register fragment is just a per-thread layout over physical registers. The missing piece was: how is the tile divided up across threads in the first place? That is what a TV layout specifies.

Part A — General foundations
What a TV layout is

A TV layout is a CUTE layout whose two top-level modes are interpreted as \((T, V)\) — Thread index and Value index within that thread. It is a function

\[(T,\; V) \;\longrightarrow\; \text{tile\_position}\]

that says, for every thread \(T\) and every value slot \(V\) in that thread’s fragment, which position in the underlying tile that pair refers to. It is still just a CUTE layout — shape and stride — but the convention is that the first mode counts threads and the second counts values per thread.

Every tiled MMA exposes three TV layouts, one per operand:

tiled_mma.tv_layout_A   # how threads partition the A operand tile
tiled_mma.tv_layout_B   # how threads partition the B operand tile
tiled_mma.tv_layout_C   # how threads partition the C output tile
The two axes of a TV layout

For any TV layout:

  • tv_layout.shape[0] — Thread mode (how many threads, possibly hierarchical).
  • tv_layout.shape[1] — Value mode (how many values per thread, possibly hierarchical).
  • tv_layout.stride[0] — strides telling you how each bit of the thread index moves you in the tile.
  • tv_layout.stride[1] — strides telling you how each bit of the value index moves you in the tile.

The crucial property of WGMMA C-layouts: both the Thread mode and the Value mode contribute to both M and N of the tile. Some sub-bits of the thread index walk you along M, others along N; same for the value index. This is why a single thread’s fragment is not a clean sub-rectangle of the tile — it is scattered across the tile in a hardware-defined pattern.

A concrete example: SM90 WGMMA, M = N = 64

For an SM90 WGMMA atom with \(M = N = 64\), the C-layout shape is roughly:

tv_layout_C.shape  =  ( (4, 32),   ((2, 2), 8) )
                       ^^^^^^^      ^^^^^^^^^^
                       Threads (T)  Values (V)

128 threads (\(= 4 \times 32\), one warpgroup of 4 warps × 32 lanes) each holding 32 values (\(= 2 \times 2 \times 8\)). The strides — the actual mapping into the tile — are hardware-defined and what tv_layout_C.stride carries.

Per thread, this typically resolves to:

Quantity Value Meaning
Total threads 128 one warpgroup
Total tile elements 4096 \(64 \times 64\)
Values per thread 32 \(4096 / 128\)
Distinct M rows / thread 2 mn-view mode 0
Distinct N cols / thread 16 mn-view mode 1
Threads sharing a row 4 reduce across these for row-max / row-sum
Threads on different rows 32 \(128 / 4\)
The “stride < M means walks-along-M” test

Picture the tile flattened M-major into a 1D buffer of 4096 entries. Addresses 0..63 form M-row 0, addresses 64..127 form M-row 1, and so on. Then:

  • A stride \(< M\) moves you within a single row — it walks along M’s contiguous direction.
  • A stride \(\geq M\) moves you across rows — it walks along N.

This is the criterion CUTE uses to peel apart M-contributions and N-contributions of any hierarchical mode. We will use it twice — once on the value mode (to build a clean mn-view of a fragment) and once on the thread mode (to find the threads that share a row).

Part B — Building on these foundations for K4
layout_separate: the M/N peeler
@staticmethod
def layout_separate(thr, src, ref):
    lt = cute.make_layout(())
    ge = cute.make_layout(())
    for k, v in enumerate(ref):
        if cutlass.const_expr(v < thr):
            lt = cute.append(lt, src[k])
        else:
            ge = cute.append(ge, src[k])
    ...

Ignoring the rank-1 packaging boilerplate, this takes a layout src and a parallel stride list ref, and splits src’s sub-modes into two groups using the criterion above — sub-modes whose corresponding stride in ref is \(<\) thr go into lt, the rest go into ge. In K4, thr is always tiled_mma.shape_mnk[0], which is \(M\). So lt is the M-walking part and ge is the N-walking part.

layout_separate is the surgical tool: given a hierarchical mode that mixes M and N contributions, peel it apart into (M-part, N-part).

layout_acc_mn: the friendly mn-view
@cute.jit
def layout_acc_mn(self, tiled_mma, acc):
    separated = self.layout_separate(
        tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
    )
    V_M = separated[0]
    V_N = separated[1]
    ...
    return cute.append(... V_M1, V_N1 ...)

acc[0] is the value-mode of the per-thread accumulator fragment — the layout describing those 32 elements one thread holds. The reference list it splits against is tiled_mma.tv_layout_C.stride[1], the value-mode stride of the C TV layout. The reasoning:

  • acc[0] and tv_layout_C.shape[1] describe the same thing — the per-thread value mode — because the accumulator fragment was built from the C TV layout, so its value mode shares structure with tv_layout_C.shape[1], and the canonical strides for those sub-modes are tv_layout_C.stride[1].
  • The [1] is because tv_layout_C.shape = (Threads, Values), so [0] is threads and [1] is values. Building the value part of the mn-view obviously needs the value stride.

After this transform, indexing acc_pv_mn[i, j]:

  • \(i\) walks only along M-bits of the fragment → distinct M rows of the tile.
  • \(j\) walks only along N-bits of the fragment → distinct N columns of the tile.

And cute.size(acc_pv_mn, mode=[0]) gives exactly the number of distinct M rows this thread owns — the right size for s_max[i], a_sum[i], s_max_prev[i]. For an SM90 64×64 atom, this is n_rows = 2, n_cols = 16.

So acc_pv_mn[1, 5] for thread 0 means “the second of thread 0’s two owned M rows, in the sixth of thread 0’s owned N columns” — a single tile position, not a pair of rows. The exact (M, N) indices in the tile are determined by tv_layout_C’s strides; sweeping i from 0 to 1 visits two distinct M rows of the tile, and sweeping j from 0 to 15 visits 16 distinct N columns.

reduction_target_n: the threads that share a row
@cute.jit
def reduction_target_n(self, tiled_mma):
    separated = self.layout_separate(
        tiled_mma.shape_mnk[0],
        cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
        tiled_mma.tv_layout_C.stride[0],
    )
    return separated[1]   # the N-part of the thread mode

Same layout_separate machinery, applied this time to the thread mode (tv_layout_C.shape[0] and tv_layout_C.stride[0]). It returns the part of the thread index whose stride is \(\geq M\) — i.e., the bits of the thread index that walk you along N. Threads whose thread-id differs only in those bits are siblings on the same row; they hold different N-columns of the same M-rows.

Why this matters for a row-max. A row-max is a max over all N columns of a fixed M row, so the reduction must run across all elements that share the same M coordinate but differ in N. In the WGMMA C-layout, those elements are split between:

  • Within one thread: 16 N columns owned for a given local row \(i\) — handled by the linear inner loop over \(j\).
  • Across threads: the remaining N columns of that row, held by sibling threads — handled by the cross-thread reduce, which is exactly reduction_target_n.

Reducing across the M-part of the thread mode would instead combine values from different M rows — a column-max-across-rows, completely wrong for softmax. The N-part is the only valid reduction group.

The K4 code:

for r in cutlass.range_constexpr(red_rank_qk):
    for i in cutlass.range_constexpr(n_rows):
        s_max[i] = cute.arch.warp_reduction_max(
            s_max[i],
            threads_in_group=reduction_target_qk.shape[r],
        )

iterates over the sub-modes of reduction_target_n’s output and does a __shfl_xor-style reduction whose group size is each sub-mode’s extent. For a typical SM90 atom, red_rank is 1 and shape[0] is 4, so the line reduces across 4 sibling threads. After this stage, every thread that owns row \(i\) holds the same global row-max in s_max[i], which is what makes the per-row rescale of acc_pv_mn[i, *] valid.

Why this nails the original SMEM-softmax bug

The original SMEM design assumed row = tidx // 2 — one row per thread, scalar m_i, l_i, alpha. Against the actual tv_layout_C this is wrong on two counts:

  1. Each thread actually owns 2 rows, not 1. A single per-thread scalar is missing one degree of freedom from the start.
  2. Each row is shared by 4 threads, not by 2. A warp_reduction(threads_in_group=2) reduces across the wrong neighborhood.

When acc_pv_mn[ii, jj] *= alpha was applied with one alpha per thread, the loop over ii ran across rows that needed different alphas — and only one was available. Tile elements at (row 0, col j) and (row 1, col j) were scaled by the same factor when their softmax factors are unrelated. Hence the “essentially every element wrong, but finite” fingerprint.

Why the register-softmax fix is structurally correct

The fix sizes s_max, a_sum, s_max_prev to n_rows = cute.size(acc_pv_mn, mode=[0]), so per-row state matches the actual number of M rows each thread owns. The cross-thread reduce uses reduction_target_n(tiled_mma).shape[r] rather than a hardcoded count, so it adapts to the atom’s TV layout. And because nothing in the softmax math is hardcoded, an atom with a different n_rows (e.g., 4) would Just Work — only register pressure scales with row count.

That structural correctness is why the persistence of the same bad ≈ 33.5M / 33.5M, max_abs ≈ 32.6 fingerprint after the rewrite is meaningful: the remaining bug is almost certainly outside the softmax math — somewhere in the QK→PV operand handoff or the output scatter — not in the per-row scaling itself.


Block 3 — WGMMA Operand Semantics

This block builds on Block 2 (TV layouts) by asking what the warpgroup-level matrix-multiply-accumulate instruction (WGMMA) actually consumes and produces, and how CUTE wraps that hardware.

Part A — General foundations
What WGMMA is

WGMMA is the SM90 warpgroup-level MMA instruction. One issue computes

\[C \mathrel{+}= A \cdot B\]

at warpgroup granularity (128 threads, 4 warps). The atom shape is typically \(64 \times N \times 16\) for SM90 BF16 (where \(N \in \{8, 16, \dots, 256\}\)). Three properties drive the rest of the design:

  • Async. WGMMA does not stall the warpgroup. You issue, you keep going, you fence/wait when you need the result.
  • Operand sources are constrained:
    • A can come from SMEM or registers (RF).
    • B can come only from SMEM.
    • C is always the register accumulator (RF).
  • Each operand has its own TV layout. tv_layout_A, tv_layout_B, tv_layout_C are independent objects — the per-thread layout for an A-operand is not the same as the per-thread layout for a C-accumulator, even when both happen to live in registers.

The combinatorial space for a WGMMA instruction (per dtype/shape):

  • A source: SMEM or RF → 2.
  • A major mode: K-major or MN-major → 2.
  • B major mode: K-major or MN-major → 2.

That gives 8 distinct WGMMA variants. Swizzle pattern is not a separate axis — it is coupled to the major mode (each major mode admits a small set of swizzle patterns: none, 32B, 64B, 128B), encoded into the SMEM descriptor at issue time. There is no “B from RF” variant, which is why it’s 8 and not 16.

Atom shape vs. per-thread fragment

The WGMMA atom is the smallest hardware-supported MMA shape — but it is the shape of the whole warpgroup-level tile, not the per-thread piece. For a 64×64 atom:

  • One atom = one WGMMA instruction = 4096 output elements computed cooperatively by 128 threads.
  • Per-thread fragment = \(4096 / 128 = 32\) elements held in this thread’s registers.

So 32 elements per thread is exactly consistent with a 64×64 atom — the fragment is the per-thread share of the atom’s output, not “an atom by itself.”

Why per-thread fragments matter even though Tensor Cores do the math

Tensor Cores execute the multiply-add, but software is still responsible for two things:

  • Reading and writing the accumulator (C). The accumulator lives in registers, which are per-thread. Hardware fixes which elements of the output tile each thread holds — the C TV layout. Any post-processing (softmax, store to GMEM, feed into another WGMMA) must know which tile elements live in this thread.
  • Preparing register-sourced A operands. When A comes from SMEM, the WGMMA hardware reads SMEM directly via a swizzled descriptor — software doesn’t shuffle bytes per thread. But when A is sourced from registers (as in PV’s A), each thread must have its share of A sitting in the exact physical registers the WGMMA instruction expects, in the layout described by tv_layout_A.
SMEM descriptors

An SMEM descriptor is a 64-bit value encoding where an operand tile starts in SMEM, the swizzle pattern, the leading-dim stride, and a few flags. The descriptor describes the operand tile for one WGMMA instruction — e.g., the 64×16 slice of A and the 16×N slice of B for a single WGMMA call. Larger overall operands (full-K matmuls) require multiple WGMMA instructions chained by a k_block loop, each with its own descriptor pointing at the next K-slice.

The async pattern: fence / commit_group / wait_group

Three primitives at three different roles:

  • fence() — does not wait. A one-instruction register fence: tells the WGMMA pipeline that the current state of registers is final and may be consumed. No blocking. Acts as a memory barrier for register operands feeding WGMMA (and for the accumulator if it was just written by other code).
  • commit_group() — does not wait. Bundles all WGMMAs issued since the last commit into a named group (“tag this batch as group #42”).
  • wait_group(N)this is the one that blocks. Waits until at most \(N\) committed groups remain in flight. wait_group(0) blocks until all prior commits have completed, after which all results are in registers and safe to read.

The argument \(N\) exists for pipelining. In K6-style ping-pong you might issue group A, commit, issue group B, commit, then wait_group(1) — meaning “wait until at most 1 group is still in flight, i.e., wait for A but let B keep running.” Then you process A’s results while B computes. K4 does not pipeline at this level and uses wait_group(0) everywhere.

wait_group synchronizes the issuing warpgroup with the WGMMA pipeline; it does not synchronize across warpgroups. Other warps doing other work are untouched. Warp specialization (K5+) needs different sync primitives.

The standard pattern:

fence()                          # register operands ready
for k_block in K-blocks: gemm()  # issue WGMMAs (async)
commit_group()                   # close this batch
wait_group(0)                    # wait for it
CUTE’s wrapping

CUTE exposes a few related objects:

  • A tiled MMA (sm90_utils.make_trivial_tiled_mma(...)): bundles a single WGMMA atom with the partitioning data (TV layouts for A, B, C). One per matmul.
  • A thread MMA (tiled_mma.get_slice(tidx)): a thread-local view that knows how to extract this thread’s slice of A, B, C from the warpgroup-level tile.
  • partition_X (X ∈ {A, B, C}): takes a source tile (SMEM/GMEM) and returns a tensor whose layout is this thread’s slice under operand X’s TV layout. Still points at the original storage; it is a view, not a copy.
  • make_fragment_X: wraps the partitioned tensor in the layout WGMMA’s instruction encoding expects. For SMEM sources, a layout reinterpretation. For RF sources, a register allocation.
  • make_fragment_C(shape): allocates a register tensor of the right size for the per-thread C slice, ready to pass as the accumulator to cute.gemm.
CUTE naming convention

The four-letter pattern t<X><mem><Operand> for partitioned tensors:

  • First letter = matmul this is for. S for QK (output is S, the score matrix). O for PV (output is O).
  • Second letter = always t (“thread-partitioned”).
  • Third letter = where it lives. s = SMEM, r = registers, g = global memory, c = coordinates (for identity tensors).
  • Fourth letter = which operand. Q, K, V, P, S, O.

So tSrK = “QK matmul, thread-partitioned, register-fragment view, K operand.” tOcO = “PV matmul, thread-partitioned, coordinate tensor, O destination.” Convention only — not enforced — but used consistently throughout CUTE.

Part B — Building on these foundations for K4
The two tiled MMAs

K4 has two matmuls and so two tiled MMAs.

QK (\(S = Q K^\top\), M=Br, N=Bc, K=d):

qk_tiled_mma = sm90_utils.make_trivial_tiled_mma(
    INPUT_DTYPE, INPUT_DTYPE,
    warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K,
    ACC_DTYPE,
    self.atom_layout_mnk,
    (self.Br, self.Bc),
    warpgroup.OperandSource.SMEM,
)

For Q (A operand), the contiguous axis is \(d = K\), so K-major. K (B operand) has the same logical layout, so also K-major. A from SMEM — standard FA pattern.

PV (\(O \mathrel{+}= P V\), M=Br, N=d, K=Bc):

pv_tiled_mma = sm90_utils.make_trivial_tiled_mma(
    INPUT_DTYPE, INPUT_DTYPE,
    warpgroup.OperandMajorMode.K,
    warpgroup.OperandMajorMode.MN,
    ACC_DTYPE,
    self.atom_layout_mnk,
    (self.Br, self.d),
    warpgroup.OperandSource.RMEM,
)

Two important differences:

  1. A from RMEM — so freshly-computed P (in acc_qk registers) can feed PV directly without a SMEM round-trip. Handoff via make_acc_into_op.
  2. B is MN-major — V’s natural storage has \(d\) contiguous. For PV, \(N = d\), so V is MN-major for PV. This is a major-mode trap: the same physical V tensor is K-major when consumed as K (in QK) and MN-major when consumed as V (in PV). The auto-detector v_layout_enum.sm90_mma_major_mode() returns the same answer as for K (because the strides are identical) and would incorrectly say K-major. The K4 code hardcodes OperandMajorMode.MN to override this.

The same logic appears in the SMEM atom selection (make_smem_layout_b(LayoutEnum.COL_MAJOR, ...) for V) — picks the MN-major SMEM atom and the matching tile order.

QK partitioning and the inner k_block loop
qk_thr_mma = qk_tiled_mma.get_slice(tidx)
tSsQ = qk_thr_mma.partition_A(sQ_full)        # SMEM view, this thread's slice
tSsK = qk_thr_mma.partition_B(sK_full)
tSrQ = qk_tiled_mma.make_fragment_A(tSsQ)     # WGMMA-encoded A operand
tSrK = qk_tiled_mma.make_fragment_B(tSsK)
qk_acc_shape = qk_thr_mma.partition_shape_C((self.Br, self.Bc))

Both A and B come from SMEM, so tSrQ and tSrK are layout reinterpretations, not register allocations. partition_shape_C returns the shape of the per-thread C slice; the actual register allocation happens inside the loop:

acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape)
cute.nvgpu.warpgroup.fence()
for k_block_idx in cutlass.range_constexpr(num_k_blocks_qk):
    qk_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0)
    cute.gemm(qk_tiled_mma, acc_qk,
              tSrQ_k[(None, None, k_block_idx)],
              tSrK_k[(None, None, k_block_idx)],
              acc_qk)
cute.nvgpu.warpgroup.commit_group()
cute.nvgpu.warpgroup.wait_group(0)

Standard pattern: fence, loop over K-blocks, set ACCUMULATE=False on the first issue (write) and True for the rest (add), commit, wait. The ACCUMULATE flag is a per-instruction property of WGMMA — False means “\(C = A \cdot B\)”, True means “\(C \mathrel{+}= A \cdot B\)”.

Two scales of “K”

K4 has two contraction-like loops at two different scales, and they are easy to confuse because both are called “K”:

  • Outer KV iterations (FlashAttention’s sequence chunking). Full seqlen \(N\) broken into chunks of \(B_c\). For \(N=512\), \(B_c=64\): \(n_{kv} = 8\) iterations. Each iteration loads one chunk of K and V, runs QK + softmax + PV, accumulates into acc_pv.
  • Inner k_block loop (WGMMA’s contraction-axis chunking). One matmul has a contraction axis (K of the matmul); the WGMMA atom has \(K_{atom} = 16\) (BF16). For \(d = 64\) that’s 4 k_blocks per QK matmul; for \(B_c = 64\) that’s 4 k_blocks per PV matmul. Each k_block is one WGMMA instruction.

For a \(N=512\), \(d=64\), \(B_r = B_c = 64\), \(K_{atom}=16\) problem there are \(8 \text{ KV iters} \times (4 \text{ QK} + 4 \text{ PV}) = 64 \text{ WGMMAs}\) per (batch, head, M_tile) workload.

The QK→PV handoff: make_acc_into_op

After softmax, \(P = \exp(S - s_{\max})\) lives in acc_qk as FP32 under the QK C TV layout. PV wants its A operand in registers, in BF16, under the PV A TV layout. The handoff:

@cute.jit
def make_acc_into_op(self, acc, operand_layout_tv, Element):
    operand = cute.make_rmem_tensor_like(
        self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
        Element,
    )
    operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
    acc_vec = acc.load()
    operand_as_acc.store(acc_vec.to(Element))
    return operand

Three steps:

  1. Allocate an A-layout register fragment of the right BF16 size. convert_c_layout_to_a_layout reshapes the C-fragment layout into the A-fragment layout that physically aligns with how WGMMA expects A. This works because the SM90 WGMMA register layouts for C and A are designed to be reinterpretable into one another without inter-thread shuffles for BF16 — exactly so that back-to-back GEMMs (like QK→PV) can hand off through registers.
  2. Build an aliasing view: operand_as_acc = cute.make_tensor(operand.iterator, acc.layout) — same iterator (same registers), C-shaped layout. One register region, two views.
  3. Cast and store via the C view. Reading is C-shaped (matches acc); writing routes the FP32→BF16 cast into the registers that will later be read as A-shaped by WGMMA.

The function returns operand (the BF16 A-layout fragment). This is fed directly into cute.gemm for PV. convert_c_layout_to_a_layout is a project-level helper, not a CUTE built-in — both K4 and the reference fmha.py define it identically (verified by direct comparison; character-for-character match), copied from CUTLASS C++ examples.

FP8 (E4M3) breaks the no-shuffle property: K4’s make_acc_into_op docstring notes that FP8 needs additional cross-thread shuffles on top.

ACCUMULATE for PV is always True

acc_pv is allocated once, before the mainloop, and zero-initialized. Each KV iteration adds to it: \(O \mathrel{+}= P_j V_j\). Within a single PV matmul, multiple inner k_blocks also accumulate into the same acc_pv. Both reasons require ACCUMULATE=True. (The “we reuse C as A” observation from make_acc_into_op is a separate fact about the QK→PV register handoff, not the cause of the PV accumulate flag.)

Compare to QK: ACCUMULATE = (k_block_idx != 0) — each KV iteration freshly allocates and writes acc_qk, accumulating only across the inner k_block loop.

Why acc_pv lives outside the loop and acc_qk inside
  • acc_pv carries the running attention output across KV iterations — it must be zero-initialized once and updated each iteration. Zero-initialization also makes the iter-0 unified rescale benign: with \(s_{\max,\text{prev}} = -\infty\) and \(s_{\max} =\) finite, \(\text{scale}_{pv} = e^{-\infty} = 0\), and \(0 \cdot 0 = 0\), so the rescale silently does nothing on iter 0.
  • acc_qk is recomputed from scratch each KV iteration (it’s the score block for the current K chunk, not a running quantity), so allocating it inside the loop both makes the dependency explicit and lets the compiler reuse registers across iterations.
Where does the softmax denominator go?

The full softmax row is

\[\text{softmax}(S)_{i,j} = \frac{\exp(S_{i,j} - m_i)}{\sum_k \exp(S_{i,k} - m_i)}\]

K4 computes only the numerator \(P_{i,j} = \exp(S_{i,j} - m_i)\) per iteration, accumulating two things in parallel:

  • The unnormalized weighted sum \(O\) in acc_pv.
  • The running row-sum \(\ell\) in a_sum.

Both get rescaled by \(\exp(m_{\text{prev}} - m_{\text{new}})\) when a new row-max is found, so the math stays consistent across iterations. The denominator is applied at the very end:

\[O_{\text{final}}[i, j] = \frac{O[i, j]}{\ell[i]}\]

— the final gO_tile[m_idx, n_idx] = (acc_pv_mn[i, j] * inv_l).to(OUTPUT_DTYPE) in the finalize block.

The output scatter (4.7)

The finalize writes the per-thread accumulator to GMEM. K4 uses an explicit-coordinate style:

cO = cute.make_identity_tensor((self.Br, self.d))
tOcO = pv_thr_mma.partition_C(cO)
tOcO_mn = cute.make_tensor(
    tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout),
)
for i in cutlass.range_constexpr(n_rows):
    inv_l = cute.arch.rcp_approx(a_sum[i])
    if a_sum[i] == 0.0 or a_sum[i] != a_sum[i]:
        inv_l = cutlass.Float32(1.0)
    for j in cutlass.range_constexpr(n_cols_pv):
        m_idx = tOcO_mn[i, j][0]
        n_idx = tOcO_mn[i, j][1]
        gO_tile[m_idx, n_idx] = (acc_pv_mn[i, j] * inv_l).to(OUTPUT_DTYPE)

Two CUTE primitives drive this:

  • cute.make_identity_tensor((Br, d)) creates a logical \(B_r \times d\) tensor where reading cO[m, n] yields the coordinate tuple (m, n) itself. It carries no real storage — purely a coordinate-reporting view, used as a probe.
  • pv_thr_mma.partition_C(cO) applies the same C-operand partitioning to cO that acc_pv lives under. So tOcO[v_idx] for thread T evaluates to the \((M, N)\) tile coord that T’s v_idx-th C-fragment slot would correspond to. Aliased through layout_acc_mn, tOcO_mn[i, j] returns the \((M, N)\) tile coord of T’s local row \(i\), column \(j\). Index-for-index aligned with acc_pv_mn[i, j] — same partitioning, same mn-view.

Per thread, for each owned (i, j): look up the actual \((M, N)\) tile coordinate from tOcO_mn, divide acc_pv_mn[i, j] by a_sum[i], cast to BF16, write to gO_tile[M, N]. With 128 threads each writing 32 slots, every position in the 64×64 output tile is written exactly once.

Idiomatic alternative

The more common CUTE pattern is to partition the destination directly:

tOgO = pv_thr_mma.partition_C(gO_tile)
tOgO_mn = cute.make_tensor(tOgO.iterator,
                           self.layout_acc_mn(pv_tiled_mma, tOgO.layout))
tOgO_mn[i, j] = acc_pv_mn[i, j] * inv_l   # cast as needed

This lets the layout machinery handle the per-thread→GMEM mapping the same way it does for acc_pv. The K4 explicit-coordinate version makes addresses visible (easier to debug) but adds one more place where a layout assumption could go wrong. Whether raw gO_tile[m, n] indexing into a sliced flat_divide result lands at the right physical addresses — given mO’s (seqlen, head_dim, n_heads, batch) strides — is one of the candidates for the remaining bug.


Block 4 — TMA and Pipelines

This block adds the loading half of K4. Block 3 covered what WGMMA consumes once tiles are in SMEM; this block covers how those tiles get there asynchronously, and how the load machinery overlaps with compute via pipelining.

Part A — General foundations
What TMA is

TMA = Tensor Memory Accelerator, the SM90 hardware unit that performs bulk asynchronous copies of multidimensional tiles between GMEM and SMEM. Two properties distinguish it from the SM80 cp.async model:

  • One thread issues the whole copy. A single thread executes the TMA instruction and hardware does the rest. With cp.async, every thread issues its own per-element load.
  • The tile shape is multidimensional and known to hardware. Before the kernel runs, you build a tensor descriptor that captures the source tensor’s shape, strides, and dtype. The hardware uses this descriptor to walk the source tile correctly.

Other key properties:

  • Async, with completion tracked by an mbarrier. When the copy completes, hardware arrives on a barrier with the byte count it transferred.
  • G2S (load) and S2G (store). K4 uses G2S only.
  • Optional multicast. One TMA can broadcast the same tile to multiple CTAs; K4 has num_multicast=1 (no broadcast).
TMA atoms vs SMEM descriptors

These are completely separate hardware-and-software concepts and easy to confuse:

  • A SMEM descriptor is a 64-bit value describing the operand tile of one WGMMA instruction (one k_block). A sequence of WGMMAs walks through SMEM via successive descriptors.
  • A TMA atom describes how to load a tile from GMEM into SMEM. Different hardware unit, different purpose.

WGMMA reads SMEM that TMA wrote; neither knows the other exists. The link between them is purely the SMEM layout — both ends agree on where bytes live.

Building a TMA atom

A TMA atom bundles a TMA-issue helper, a tensor descriptor, and the SMEM destination layout:

tma_atom_q, tma_tensor_q = cute.nvgpu.cpasync.make_tiled_tma_atom(
    cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(),
    mQ, sQ_layout, (self.Br, self.d), num_multicast=1,
)
  • The op type — CopyBulkTensorTileG2SOp says “G2S, single CTA, tile-mode.”
  • mQ — the GMEM source tensor; the descriptor records its shape/strides/dtype.
  • sQ_layout — the SMEM destination layout (per-stage); tells TMA the target tile shape and SMEM swizzle.
  • (Br, d) — the tile shape copied per issue.
  • num_multicast=1 — no broadcast.

Returns:

  • tma_atom_q — opaque handle holding the descriptor and metadata; first arg of cute.copy(...).
  • tma_tensor_q — a re-wrapped view of mQ with TMA-compatible layout. Use this from here on; cute.copy from raw mQ will not work.

cute.nvgpu.cpasync.prefetch_descriptor(tma_atom) prefetches the descriptor into the SM caches so the first TMA issue avoids descriptor-fetch latency. K4 does this for all three operands on warp 0.

tma_partition and the partitioned tensors

Once the atom is built, partition source and destination for issue:

tQsQ, tQgQ = cute.nvgpu.cpasync.tma_partition(
    tma_atom_q, 0, cute.make_layout(1),
    cute.group_modes(sQ_full, 0, 2),
    cute.group_modes(gQ_tiles, 0, 2),
)
  • 0, cute.make_layout(1) — multicast slot index and CTA layout. 0 plus a trivial 1-element layout = “no multicast, this CTA only.”
  • cute.group_modes(sQ_full, 0, 2) — SMEM dest with the first two modes (tile interior \((B_r, d)\)) collapsed. Result: (tile, n_stages).
  • cute.group_modes(gQ_tiles, 0, 2) — GMEM source with first two modes collapsed. Result: (tile, n_M_tiles, …).

Why group_modes? TMA treats “the tile” as a single opaque blob (it knows how to walk it via the descriptor), so collapse the tile-interior modes into one. Outer modes remain as which-tile coordinates.

Returns:

  • tQsQ — partitioned SMEM dest, shape (tma_unit, n_stages).
  • tQgQ — partitioned GMEM source, shape (tma_unit, tile_coord).

Naming convention here is t<Op><mem><Operand>: t for TMA-partitioned, Q for the operand, s or g for SMEM or GMEM.

A copy then uses these:

cute.copy(
    tma_atom_q,
    tQgQ[(None, bidx_m)],                   # GMEM tile to load from
    tQsQ[(None, q_producer_state.index)],   # SMEM stage to load into
    tma_bar_ptr=q_pipeline.producer_get_barrier(q_producer_state),
)

tma_bar_ptr is the mbarrier the TMA will arrive on when complete.

mbarriers, in detail

An mbarrier (“memory barrier”) is a small SMEM object — not a bit. It holds:

  • A byte counter, which TMAs arrive on with the number of bytes they transferred.
  • A phase bit, which flips each time the counter reaches its expected total (tx_count) and resets.
  • An expected arrival count (tx_count), set at init.

The lifecycle of one fire:

  1. Init — tx_count is set to the total bytes expected; the phase is in its waiting state.
  2. TMA issues; hardware copies bytes; on completion, hardware arrives on the barrier with the byte count.
  3. When cumulative arrived bytes ≥ tx_count, the phase flips. This is the “fire.”
  4. consumer_wait blocks until it observes the phase flip. After it returns, the data is fully visible in SMEM.
  5. After the consumer is done, it arrives on a separate barrier (the “empty” barrier of the same stage) so the producer knows the slot can be refilled.

Each pipeline stage has two mbarriers: a “full” barrier (TMA → consumer) and an “empty” barrier (consumer → producer). Together they enforce that stages cycle correctly through the circular buffer.

For an aggregate barrier whose tx_count is the sum of multiple TMAs (e.g., K + V into the same stage), the barrier only fires when all contributing TMAs have arrived. This is how K4 keeps K and V tied to a single stage with a single barrier.

Pipelining: producer/consumer with circular buffers

Without pipelining, you would issue one load, wait, consume, issue the next. The SM is idle whenever it waits. Pipelining lets you keep N loads in flight simultaneously so the consumer always finds data ready.

The data structure is a circular buffer of N stages in SMEM. With N stages, the producer can be up to N stages ahead of the consumer.

Producer (the warp issuing TMAs):

  • producer_acquire(state) — waits until stage state.index is empty (consumer has released it).
  • Issue the TMA into stage state.index, with the stage’s “full” barrier as the completion mbarrier.
  • producer_commit(state) — no-op for TMA (TMA itself signals the full barrier on completion); explicit arrive for cp.async.
  • state.advance() — move to next stage.

Consumer (the warps doing WGMMA):

  • consumer_wait(state) — wait on stage state.index’s full barrier.
  • Use the data in stage state.index.
  • consumer_release(state) — arrive on stage state.index’s empty barrier so producer can refill.
  • state.advance() — move to next stage.

Steady state. With N stages and a long enough loop, the producer and consumer settle into a steady offset where every consumer release is matched by a new producer issue, keeping the buffer filled at depth N. The first N iterations are warmup (consumer is consuming pre-loaded data); the last N iterations are drain (producer has nothing left to issue).

The N stages are the latency-hiding budget. Bounded by SMEM size (each stage is a full tile) and by the number of barriers managed.

Pipeline state

A pipeline state carries two fields:

  • index — current stage in the circular buffer; wraps modulo num_stages (e.g., for num_stages=5: 0,1,2,3,4,0,1,2,…).
  • count — total stages this state has advanced through, not wrapping. Monotonic — useful for “have I done N iterations yet?” checks.

Producer and consumer have separate states. Both start at (index=0, count=0) and walk forward independently; the barriers ensure they never desync past num_stages.

Part B — Building on these foundations for K4
The two pipelines: Q and KV

K4 has two pipelines because the operands have different reuse patterns.

q_pipeline = pipeline.PipelineTmaAsync.create(
    barrier_storage=storage.q_pipeline_array_ptr.data_ptr(),
    num_stages=self.q_stage,                                       # 1
    producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
    consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps),
    tx_count=self.q_bytes,
    cta_layout_vmnk=cute.make_layout((1, 1, 1, 1)),
    defer_sync=True,
)
  • q_stage = 1 — Q is loaded once per CTA and reused across all KV iterations, so one stage suffices.
  • kv_stage = 5 — K and V are consumed and replaced each iteration; multiple stages let the producer prefetch ahead.
  • producer_group=Agent.Thread — exactly one thread (warp 0’s first thread) issues TMAs; matches TMA’s hardware issue model.
  • consumer_group=Agent.Thread, num_warps — all \(\text{num\_warps} \times 32 = 128\) threads consume; they all execute WGMMA and need to see the data.
  • tx_count=q_bytes — exactly the bytes per Q stage; the barrier fires when this many have arrived.
  • defer_sync=True — defers cross-CTA init sync; allowed because pipeline_init_arrive / pipeline_init_wait is called manually.

The KV pipeline is identical structure but with num_stages=kv_stage=5 and:

kv_bytes = (cute.size_in_bytes(INPUT_DTYPE, sK_layout)
          + cute.size_in_bytes(INPUT_DTYPE, sV_layout))

tx_count = kv_bytes is the sum of K-bytes and V-bytes per stage. K and V land in separate SMEM buffers but share a stage index and a single barrier. The barrier only fires when both TMAs have arrived (cumulative bytes ≥ tx_count). If tx_count were set to just sK_bytes, the barrier would fire after K alone, and the consumer would race with V’s still-in-flight load.

K and V in one stage, two buffers

K and V live in separate SMEM buffers (storage.sK, storage.sV) but share a stage index. The stage is just a slot number; stage 3 of K and stage 3 of V are filled together and consumed together.

tSrK_k = tSrK[(None, None, None, kv_consumer_state.index)]   # K slice
tOrV_k = tOrV[(None, None, None, kv_consumer_state.index)]   # V slice, same index

Same number, different SMEM regions. The pipeline doesn’t care that two physical tiles are involved — it tracks “is stage 3 ready” via one barrier with tx_count = sK_bytes + sV_bytes.

Q load: one-shot
if warp_idx == 0:
    q_pipeline.producer_acquire(q_producer_state)
    cute.copy(tma_atom_q,
              tQgQ[(None, bidx_m)],
              tQsQ[(None, q_producer_state.index)],
              tma_bar_ptr=q_pipeline.producer_get_barrier(q_producer_state))
    q_pipeline.producer_commit(q_producer_state)
    q_producer_state.advance()

q_pipeline.consumer_wait(q_consumer_state)

bidx_m is the M-tile index for this CTA. Q is indexed by M only; each CTA loads one Q tile and reuses it. After the consumer wait, Q is in SMEM and never reloaded — q_consumer_state.index = 0 for the entire kernel.

KV prefetch

Before the mainloop, warp 0 issues up to kv_stage KV loads to fill the pipeline:

prefetch_count = cutlass.min(self.kv_stage, n_kv)
if warp_idx == 0:
    for _ in cutlass.range(prefetch_count, unroll=1):
        kv_pipeline.producer_acquire(kv_producer_state)
        cute.copy(tma_atom_k,
                  tKgK[(None, kv_producer_state.count)],
                  tKsK[(None, kv_producer_state.index)],
                  tma_bar_ptr=kv_pipeline.producer_get_barrier(kv_producer_state))
        cute.copy(tma_atom_v,
                  tVgV[(None, kv_producer_state.count)],
                  tVsV[(None, kv_producer_state.index)],
                  tma_bar_ptr=kv_pipeline.producer_get_barrier(kv_producer_state))
        kv_pipeline.producer_commit(kv_producer_state)
        kv_producer_state.advance()

Two distinct uses of the state:

  • kv_producer_state.count is the GMEM tile index — which KV chunk to load (0, 1, 2, …, \(n_{kv}-1\)).
  • kv_producer_state.index is the SMEM stage — which slot to write into (cycles 0..\(kv_{stage}-1\)).

K and V are loaded into the same stage; the state advances once per pair, not once per TMA. Both TMAs target the same barrier, which fires only when both have arrived (per tx_count = sK_bytes + sV_bytes).

Mainloop: consume + reissue
for _j in cutlass.range(n_kv, unroll=1):
    kv_pipeline.consumer_wait(kv_consumer_state)
    # ... QK + softmax + PV using kv_consumer_state.index ...
    kv_pipeline.consumer_release(kv_consumer_state)
    kv_consumer_state.advance()

    if warp_idx == 0 and kv_producer_state.count < n_kv:
        kv_pipeline.producer_acquire(kv_producer_state)
        cute.copy(tma_atom_k, ...)
        cute.copy(tma_atom_v, ...)
        kv_pipeline.producer_commit(kv_producer_state)
        kv_producer_state.advance()

Steady state at kv_stage = 5: first 5 iterations consume pre-loaded data; from iteration 5 the producer and consumer move in lockstep with the buffer at depth 5; last 5 iterations the producer is idle (count reached n_kv).

The consumer release fires after the PV WGMMA’s wait_group(0), so V is fully consumed by then — the K-then-V usage within one stage is safe.

Where this layer’s bugs could live

The pipeline machinery is well-trodden, but two K4-specific concerns:

  • V partitioning consistency. The same V tile is partitioned twice with two different intents: tVsV (TMA → SMEM, via tma_partition) for loading, and tOsV = pv_thr_mma.partition_B(sV_full) for consumption by PV’s WGMMA. The load uses LayoutEnum.COL_MAJOR to force the MN-major SMEM atom; the consume uses OperandMajorMode.MN on the tiled MMA. If those two choices don’t actually match in the SMEM byte layout they imply, the load lays out V in one pattern and the WGMMA reads it under another. Worth verifying in the bug hunt.
  • tx_count correctness. kv_bytes = sK_bytes + sV_bytes looks right because both TMAs target the barrier. size_in_bytes depends only on element count and dtype, so the forced COL_MAJOR layout for V should still produce the right byte count. Likely not a bug source, but easy to sanity-check.
  • Stage-index sharing. kv_consumer_state.index is reused for both K (in QK) and V (in PV). Structurally fine because they were loaded into the same stage; verified by the release happening after wait_group(0) of PV.

Block 5 — Bug Hunt: Suspect Inventory

The fingerprint and what it means

bad ≈ 33,539,982 / 33,554,432, max_abs ≈ 32.6, values finite. Essentially every output element is wrong, but nothing is NaN or Inf. This is a layout-mismatch fingerprint — values land at the wrong positions, or are scaled by the wrong factors, but the arithmetic itself never blows up. NaN/Inf would imply a softmax-stability bug; we’re not seeing that.

What is already ruled out
  • Register-softmax math. Block 2 established that the rewrite is structurally correct against tv_layout_C: n_rows = cute.size(acc_pv_mn, mode=[0]), cross-thread reduction via reduction_target_n, per-row rescale aligned with the mn-view. Same fingerprint persists after the rewrite, so the bug is not here.
  • QK matmul in isolation. Independently verified in a K4_qk_only experiment.
  • convert_c_layout_to_a_layout and make_acc_into_op contents. Character-for-character match with the reference fmha.py (verified by direct comparison). If something is wrong at this layer, it is in how the output is consumed, not in the function itself.
Suspect 1 — V’s load-vs-consume layout disagreement (HIGHEST priority)

The setup. V is partitioned twice with two different intents:

  • LoadtVsV from cute.nvgpu.cpasync.tma_partition(...) writes V from GMEM into SMEM. The SMEM destination layout was constructed via:
    sV_layout_staged = sm90_utils.make_smem_layout_b(
        utils.LayoutEnum.COL_MAJOR, pv_mma_tile_mnk, INPUT_DTYPE, self.kv_stage,
    )
    

    The LayoutEnum.COL_MAJOR is a deliberate override (per the K4 comment) to pick an MN-major SMEM atom for V.

  • ConsumetOsV = pv_thr_mma.partition_B(sV_full) reads V from SMEM into the WGMMA. pv_tiled_mma was built with OperandMajorMode.MN on B.

Why this is suspicious. Two independent code paths each made an “MN-major” choice. They should describe the same SMEM byte arrangement, but they descend from different sides of the pipeline:

  • The TMA-side choice (LayoutEnum.COL_MAJORmake_smem_layout_b) ends up with a particular SMEM atom + swizzle + tile order. This dictates the byte pattern TMA writes.
  • The MMA-side choice (OperandMajorMode.MN on make_trivial_tiled_mma) configures WGMMA’s SMEM descriptor to read with a particular swizzle and stride.

If those two choices disagree in any sub-detail — swizzle width (32B vs 64B vs 128B), inner-dim ordering, or atom sub-shape — TMA writes one pattern and WGMMA reads under a different pattern. The data is technically there, but at every position it’s the wrong byte. The fingerprint exactly matches: every element wrong, finite, magnitude consistent with “this is some valid V element, just not the right one.”

Why this is the highest-priority suspect. This is the only place in K4 where two independent CUTE primitives had to agree on a non-default SMEM layout choice. Most layout decisions in the kernel are auto-derived from a single source of truth (the operand’s storage layout). V is the exception: K’s storage is identical to V’s, but V needs a different major-mode label, and K4 had to manually override two APIs in two different ways to express that. Every place a manual override is split across two APIs is a place to look.

Diagnostic. The cleanest check is to compare the byte layout K4 builds against the reference. Specifically:

  • What does the reference fmha.py pass to make_smem_layout_b for V? Is it LayoutEnum.COL_MAJOR, or something else (e.g., a derived enum from V’s storage)?
  • Does the reference build pv_tiled_mma with OperandMajorMode.MN exactly, or via a different helper that derives both ends from a single source?
  • If the reference uses a single helper that returns both the SMEM layout and the MMA major-mode in lockstep, that’s what K4 should switch to. Splitting them across two API calls is the structural risk.

A weaker but quicker check: print cute.cosize(sV_layout) and cute.cosize of the WGMMA-expected B operand layout. If the byte counts match but the arrangement differs, you’ll know to look at swizzle/stride. If they don’t match, you’ve found the disagreement directly.

Suspect 2 — Output scatter via raw gO_tile[m, n] indexing

The setup.

gO = cute.flat_divide(mO, (Br, d))
gO_tile = gO[(None, None, bidx_m, 0, bidx_b)]
...
for i, j:
    m_idx = tOcO_mn[i, j][0]
    n_idx = tOcO_mn[i, j][1]
    gO_tile[m_idx, n_idx] = (acc_pv_mn[i, j] * inv_l).to(OUTPUT_DTYPE)

Why this is suspicious. gO_tile is what falls out of flat_divide after slicing. flat_divide produces a layout whose modes are tile-interior + tile-coords, and slicing fixes the tile-coords. The result has shape (Br, d) but its stride is inherited from the parent tensor mO’s storage layout — which has shape (seqlen, head_dim, n_heads, batch) and non-trivial strides accordingly.

When you index gO_tile[m_idx, n_idx], CUTE evaluates the layout function m_idx * stride[0] + n_idx * stride[1] against the inherited strides. This should be correct — that is exactly what the layout system is designed to do. But if the strides emerging from flat_divide are not what you’d naively expect (for example, because the source tensor was laid out as (seqlen, n_heads, head_dim) instead of (seqlen, head_dim, n_heads)), the writes go to the wrong addresses.

The fingerprint also fits: every output element is wrong, finite, and has a magnitude consistent with “some other valid attention output value just landed here.”

Why this is plausible despite the layout system being designed for this. The reference fmha.py uses the idiomatic partition_C(gO_tile) pattern, which lets the layout machinery handle the per-thread → GMEM mapping the same way it does for acc_pv. K4 uses the explicit-coordinate pattern, which works only if gO_tile’s layout strides match the assumption that the WGMMA C-layout’s (M, N) coords are the right keys.

Diagnostic. Two cheap checks:

  1. Switch to the idiomatic pattern temporarily, just for the output:
    tOgO = pv_thr_mma.partition_C(gO_tile)
    tOgO_mn = cute.make_tensor(tOgO.iterator,
                               self.layout_acc_mn(pv_tiled_mma, tOgO.layout))
    for i, j:
        tOgO_mn[i, j] = (acc_pv_mn[i, j] * inv_l).to(OUTPUT_DTYPE)
    

    If the bug fingerprint changes (in particular, if bad drops dramatically), this is your bug.

  2. Print gO_tile.layout once from one thread. If the strides aren’t (d, 1) or (1, Br) (whichever you expect), the explicit indexing is wrong.
Suspect 3 — tOrP not wrapped through make_fragment_A

The setup. In K4:

tOrP = self.make_acc_into_op(acc_qk, pv_tiled_mma.tv_layout_A, INPUT_DTYPE)
...
cute.gemm(pv_tiled_mma, acc_pv,
          tOrP[(None, None, k_block_idx)],
          tOrV_k[(None, None, k_block_idx)],
          acc_pv)

The reference does the same — acc_qk_fixed = make_acc_into_op(...) then cute.gemm(pv_tiled_mma, acc_pv, acc_qk_fixed, ...). So not wrapping through make_fragment_A is consistent with the reference, which suggests this is fine in principle.

Why this is still worth a glance. The reference passes acc_qk_fixed whole to cute.gemm and lets cute.gemm slice the k-block dimension internally. K4 slices it manually with tOrP[(None, None, k_block_idx)]. This requires tOrP’s layout to actually have a k_block mode in the right position with the right stride — which it should, by construction in convert_c_layout_to_a_layout, but only if the C-layout going in had the right structure to be reshapeable that way.

Why this is lower priority. If this were the bug, you’d typically see a more localized failure (e.g., partial corruption, or specific k_block contributions wrong). The “every element wrong” fingerprint is more consistent with a global layout disagreement than a per-k_block slicing error. Still worth ruling out.

Diagnostic. Match the reference call style: pass tOrP whole and let cute.gemm handle k-block slicing.

pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.gemm(pv_tiled_mma, acc_pv, tOrP, tOrV[(None, None, None, kv_consumer_state.index)], acc_pv)

If this changes the result, the manual k_block slicing was wrong.

Suspect 4 — acc_pv zero-init not visible to WGMMA on iter 0

The setup.

acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape)
acc_pv_mn = cute.make_tensor(acc_pv.iterator, self.layout_acc_mn(pv_tiled_mma, acc_pv.layout))
for i, j:
    acc_pv_mn[i, j] = cutlass.Float32(0.0)
...
# iter 0 of mainloop:
cute.gemm(pv_tiled_mma, acc_pv, tOrP, tOrV, acc_pv)   # ACCUMULATE=True

The PV gemm with ACCUMULATE=True reads the existing register state of acc_pv and adds \(P \cdot V\) to it. For iter 0 to be correct, acc_pv’s registers must hold zero at the moment WGMMA reads them.

Why this is suspicious. Between the zero-init loop and the first PV WGMMA, there is a cute.nvgpu.warpgroup.fence() immediately before the QK gemm — but the fence relevant for PV is the one before the PV gemm. K4 places it in the right spot:

cute.nvgpu.warpgroup.fence()
tOrV_k = tOrV[(None, None, None, kv_consumer_state.index)]
num_k_blocks_pv = cute.size(tOrP, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks_pv):
    pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
    cute.gemm(...)

The fence is between the zero-init (and the softmax computation that reads/writes registers) and the PV gemm. So zero-init should be visible. The risk is more subtle: if the compiler optimizes away the zero-init writes (because it sees acc_pv_mn[i, j] = 0.0 followed by acc_pv += ... and decides to fuse), the registers may never actually be written to zero.

Why this is lower priority. The fingerprint magnitude (max_abs ≈ 32.6) is large compared to typical softmax-normalized output values, which is consistent with garbage-initial-state. But a global layout disagreement (Suspect 1) explains the same magnitude equally well, and the fingerprint coverage (essentially all elements wrong) is more consistent with a structural issue than an iter-0 init issue.

Diagnostic. Replace the manual zero-init with a cute.gemm that has ACCUMULATE=False on iter 0’s first k_block, like the reference does in softmax_step’s wrapper:

if iter == 0 and k_block_idx == 0:
    pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
else:
    pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)

This makes the first WGMMA write the accumulator instead of accumulating into it, which sidesteps the visibility-of-zero question entirely. If the bug disappears with this change, the zero-init was the issue.

Suspect 5 — tOcO mn-view coordinates wrong

The setup.

cO = cute.make_identity_tensor((self.Br, self.d))
tOcO = pv_thr_mma.partition_C(cO)
tOcO_mn = cute.make_tensor(tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout))

tOcO_mn[i, j] is supposed to return the (M, N) tile coord for thread T’s local row \(i\), column \(j\).

Why this is suspicious. layout_acc_mn was designed for value-mode layouts where acc[0] is the per-thread value layout, acc[1] and acc[2] are M and N tile coords. When applied to tOcO (a coordinate tensor partitioned the same way acc_pv is), the structure should match — but only if partition_C(cO) produces the same layout shape as make_fragment_C does. There is no fundamental reason it shouldn’t (both partition through tv_layout_C), but if partition_C of an identity tensor returns a slightly different layout shape (e.g., extra rank-1 modes from broadcasting), layout_acc_mn could split it incorrectly and return an mn-view that maps (i, j) to the wrong tile coordinate.

Why this is lower priority. This pattern is used in the reference fmha.py in essentially the same way (see lines 1146-1149 of the reference). If it were broken for tOcO, the reference would also be broken.

Diagnostic. Print tOcO_mn[0, 0], tOcO_mn[1, 0], tOcO_mn[0, 1] from thread 0. The first should be (0, 0) (or some specific known position depending on the atom). The second should differ in the M coord. The third should differ in the N coord. If the pattern doesn’t make sense, the layout is wrong.

Diagnostic plan: order of operations

The suspects are ranked by likelihood given the fingerprint. A practical sequence:

  1. Suspect 2 first — switch the output scatter to the idiomatic partition_C(gO_tile) pattern. It is the cheapest change (≈10 lines) and rules out the most common bug class.
  2. Suspect 1 next — diff K4’s V SMEM-layout construction against the reference’s. If the reference uses a single helper that derives both the SMEM atom and the MMA major-mode together, replicate that.
  3. Suspect 4 — replace zero-init with ACCUMULATE=False on iter 0’s first PV k_block.
  4. Suspect 3 — match the reference call style for cute.gemm with tOrP passed whole.
  5. Suspect 5 — sanity-check tOcO_mn coordinates with prints.

If after Suspect 2 the bug fingerprint changes substantially (bad drops, even if not to zero), the output scatter was a bug — there may still be others, and the residual fingerprint will guide which suspect to attack next. If Suspect 2 changes nothing, Suspect 1 is by far the most likely remaining cause.