K4 in CuTe [FA3]
This is the in-depth companion to the K4 section of the FA3 worklog. Refer back there for the high-level idea — WGMMA replaces both scalar GEMMs, K4.1 launders through SMEM, K4.2 keeps everything in registers. This post covers the mechanics: WGMMA from first principles, the CuTe API surface, the K4.1 walkthrough, TV layouts properly, the K4.2 register softmax and register handoff, and the V mode-order fix.
Two kernel files: kernels/k4_qk.py (K4QKFMHA) and kernels/k4.py (K4FMHA) in the project repo.
WGMMA from first principles
The unit
WGMMA (Warp-Group Matrix Multiply-Accumulate) is the SM90 warpgroup-level MMA instruction. One issue computes
\[C \mathrel{+}= A \cdot B\]at warpgroup granularity (128 threads, 4 warps), with hardware doing all of the multiply-add work on the tensor cores. Three defining properties:
- Warpgroup-collective. All 128 threads of a warpgroup must issue the same WGMMA together. Hardware reads operand fragments from those 128 threads’ registers and writes the output fragment back to the same 128 threads.
- Asynchronous. WGMMA returns immediately. The math actually happens in the tensor-core pipeline; you have to fence and wait when you want to read the result.
- Atom-shaped. The instruction operates on a fixed shape per issue, \(M \times N \times K\). For SM90 BF16, the atom is \(64 \times N \times 16\) with \(N \in \{8, 16, \ldots, 256\}\). Larger matmuls are built by iterating WGMMA calls over the K axis (the “k_block” loop).
Atom shape vs per-thread fragment
The atom shape (say \(64 \times 64 \times 16\) for our QK config) is the shape of the warpgroup-level tile. One WGMMA call produces \(64 \times 64 = 4096\) output elements cooperatively across 128 threads. Each thread therefore holds \(4096 / 128 = 32\) output elements in its registers. Those 32 elements are not a clean sub-rectangle of the tile; they are scattered across the tile in a hardware-defined pattern.
The pattern is described by the C TV layout (tv_layout_C), which we’ll unpack properly when K4.2 needs it. For K4.1 the only thing you need to know is: the C-fragment is per-thread, the 32 elements are at specific tile positions decided by hardware, and you read them via a coordinate tensor that tells you which tile position each fragment slot maps to.
The same scattering applies to A and B operands when they come from registers. K4.1 takes both A (Q) and B (K) from SMEM, so the per-thread A/B fragments are layout reinterpretations rather than real register holdings, and we don’t have to think about them. K4.2 changes this for PV’s A operand.
Operand sources and the 8 variants
WGMMA constrains where its operands can come from:
- A can come from SMEM or registers (RF).
- B can come only from SMEM.
- C is always the register accumulator (RF).
For each SMEM operand, the leading axis (the axis that’s contiguous in memory) can be either K-major or MN-major. With 2 choices for A source × 2 for A major × 2 for B major, the WGMMA hardware exposes 8 distinct variants per dtype and shape. Swizzle isn’t a separate axis; each major mode admits a small set of swizzle patterns, encoded into the SMEM descriptor at issue time.
K4.1 uses A from SMEM (K-major), B from SMEM (K-major) for QK. Both Q and K have \(d\) contiguous in SMEM, and \(d\) is the contraction axis of QK, so K-major is natural for both.
SMEM descriptors
When A or B come from SMEM, the WGMMA instruction reads them via a 64-bit SMEM descriptor: a packed value encoding the operand tile’s start address, swizzle pattern, leading-dim stride, and a few flags. The descriptor describes the operand tile for one WGMMA call (one k_block); for a full-K matmul we issue multiple WGMMAs, each with a descriptor pointing at the next K-slice.
You never construct these descriptors by hand. The tiled MMA object plus the SMEM layout produced by the sm90 helpers carry enough information that cute.gemm builds descriptors at issue time.
The async pattern: fence / commit_group / wait_group
Three primitives at three different roles:
fence() is a register fence. It does not wait. It tells the WGMMA pipeline that the current state of register operands is final and may be consumed by the next WGMMA issue. Place it before the WGMMAs that read those registers.
commit_group() is a tagging operation. It does not wait. It bundles all WGMMAs issued since the last commit into a named group (“call this batch group #N”). You issue, you commit, you wait on the group.
wait_group(N) is the one that actually blocks. It waits until at most \(N\) committed groups remain in flight. wait_group(0) blocks until all prior commits have completed.
The argument \(N\) exists for pipelining: K6 will issue group A, commit, issue group B, commit, then wait_group(1) (“wait for A but let B keep running”). K4 doesn’t pipeline at the WGMMA level and uses wait_group(0).
Standard pattern:
fence() # register operands ready for WGMMA
for k_block in K-blocks:
cute.gemm(tiled_mma, C, A_k, B_k, C) # issue WGMMAs (async)
commit_group() # close this batch
wait_group(0) # wait for it
CuTe’s WGMMA wrapping
Tiled MMA (sm90_utils.make_trivial_tiled_mma) bundles a WGMMA atom with the TV layouts for A, B, C.
qk_tiled_mma = sm90_utils.make_trivial_tiled_mma(
INPUT_DTYPE, INPUT_DTYPE, # A, B dtypes (BF16)
warpgroup.OperandMajorMode.K, # A major mode
warpgroup.OperandMajorMode.K, # B major mode
ACC_DTYPE, # accumulator dtype (FP32)
(1, 1, 1), # atom layout (one atom per warpgroup)
(Br, Bc), # M, N of the tile
warpgroup.OperandSource.SMEM, # A source (B is implicitly SMEM)
)
Thread MMA (tiled_mma.get_slice(tidx)) is a thread-local view that knows how to extract this thread’s slice of A, B, C from the warpgroup-level tile.
partition_A / partition_B / partition_C take a source tile (SMEM, GMEM, or a coordinate tensor) and return a tensor whose layout is this thread’s slice under the relevant operand’s TV layout. View, not copy.
make_fragment_A / make_fragment_B / make_fragment_C wrap a partitioned tensor in the layout that the WGMMA instruction encoding expects.
- For SMEM-sourced operands, this is a layout reinterpretation. No register allocation.
- For register-sourced operands, this allocates registers in the per-thread fragment layout.
- For C, this always allocates: C is always in registers.
cute.gemm(tiled_mma, D, A, B, C) issues one WGMMA: D = A * B + C, with C and D typically the same accumulator. The flag ACCUMULATE on the tiled MMA controls whether the hardware behaviour is “C = AB” or “C += AB”.
Naming convention for partitioned tensors: t<Op><mem><Operand>, where Op is S (QK matmul, output \(S\)) or O (PV matmul, output \(O\)), mem is s/r/g/c (SMEM/registers/GMEM/coordinates), and Operand is the operand name. tSsQ is “QK matmul, SMEM, Q operand”; tSrK is “QK matmul, register-fragment view, K operand.”
Swizzled SMEM layouts
K3 used plain row-major SMEM. K4 needs swizzled layouts to feed the tensor cores efficiently: the access pattern WGMMA emits when reading from SMEM hashes to specific banks under the swizzle, and the wrong layout produces conflicts that serialize the load. CuTe builds the right swizzle via sm90_utils.make_smem_layout_a / make_smem_layout_b:
sQ_layout_staged = sm90_utils.make_smem_layout_a(
q_layout_enum, mma_tile_mnk, INPUT_DTYPE, q_stage,
)
sK_layout_staged = sm90_utils.make_smem_layout_b(
k_layout_enum, mma_tile_mnk, INPUT_DTYPE, kv_stage,
)
The helpers return a ComposedLayout: a normal layout (.outer) plus a swizzle function (.inner). At SMEM allocation time we pass both:
sQ_full = storage.sQ.get_tensor(
sQ_layout_staged.outer, swizzle=sQ_layout_staged.inner,
)
From here on, indexing sQ_full[i, j] returns the value at the swizzled offset; you don’t see the permutation, but the tensor cores do. The TMA atom and the WGMMA descriptor both derive from this single layout object, so producer and consumer agree by construction.
K4.1 walkthrough
The QK matmul
Setup before the mainloop:
qk_thr_mma = qk_tiled_mma.get_slice(tidx)
tSsQ = qk_thr_mma.partition_A(sQ_full) # (MMA, MMA_M, MMA_K, q_stage)
tSsK = qk_thr_mma.partition_B(sK_full) # (MMA, MMA_N, MMA_K, kv_stage)
tSrQ = qk_tiled_mma.make_fragment_A(tSsQ) # layout reinterpretation
tSrK = qk_tiled_mma.make_fragment_B(tSsK) # layout reinterpretation
qk_acc_shape = qk_thr_mma.partition_shape_C((Br, Bc))
partition_A and partition_B produce per-thread views into the SMEM-staged tiles. The mode structure (MMA, MMA_M, MMA_K, q_stage) means “per-WGMMA-issue chunk, M-tile coord, K-block coord, pipeline stage.” The K-block coord is what we iterate over inside the WGMMA loop.
partition_shape_C returns the per-thread C fragment’s shape but doesn’t allocate. The allocation happens inside the mainloop, fresh per KV iteration.
Inside the mainloop:
acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape)
cute.nvgpu.warpgroup.fence()
tSrQ_k = tSrQ[(None, None, None, q_consumer_state.index)]
tSrK_k = tSrK[(None, None, None, kv_consumer_state.index)]
num_k_blocks = cute.size(tSrQ_k, mode=[2]) # = d / 16
for k_block_idx in range_constexpr(num_k_blocks):
qk_tiled_mma.set(
cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0,
)
cute.gemm(
qk_tiled_mma, acc_qk,
tSrQ_k[(None, None, k_block_idx)],
tSrK_k[(None, None, k_block_idx)],
acc_qk,
)
cute.nvgpu.warpgroup.commit_group()
cute.nvgpu.warpgroup.wait_group(0)
Five things going on:
1. make_fragment_C allocates the per-thread C fragment as 32 FP32 registers (for the 64×64 atom). Initial register state is undefined.
2. fence() marks the register state as ready for WGMMA consumption. Strictly not required here (the first WGMMA will overwrite C because ACCUMULATE=False), but the canonical pattern has the fence before the gemm loop.
3. Slice down to this iteration’s stages. The stage axis of tSrQ/tSrK gets fixed; the remaining mode is the k_block axis.
4. The k_block loop. The atom’s K axis is 16 BF16 elements wide. For QK contraction length \(d\), we issue d / 16 WGMMAs in sequence, each accumulating into the same acc_qk. The ACCUMULATE flag is the trick: False on the first issue (write the accumulator), True on the rest (add to it). For \(d = 64\): 4 WGMMAs; for \(d = 128\): 8.
5. commit_group() + wait_group(0) close the batch and wait. After wait_group(0) returns, acc_qk holds the per-thread fragment of \(S = QK^\top\) in FP32.
The coord-tensor scatter to sS
acc_qk is a register fragment under the WGMMA C TV layout. To write it back to a plain \((B_r, B_c)\) SMEM tile, we use an identity tensor: a logical tile that, when indexed, returns the coordinate tuple itself.
cS = cute.make_identity_tensor((Br, Bc)) # cS[m, n] = (m, n)
tScS = qk_thr_mma.partition_C(cS) # partition the same way as C
cS carries no real storage. After partition_C(cS), tScS is the per-thread view of cS under exactly the same TV layout that acc_qk lives under. So for each fragment slot of acc_qk, the corresponding slot of tScS tells you what \((m, n)\) tile coordinate that slot represents.
Indexing both under their mn-view makes the scatter mechanical:
acc_qk_mn = cute.make_tensor(
acc_qk.iterator, layout_acc_mn(qk_tiled_mma, acc_qk.layout),
)
tScS_mn = cute.make_tensor(
tScS.iterator, layout_acc_mn(qk_tiled_mma, tScS.layout),
)
for i in range_constexpr(cute.size(acc_qk_mn, mode=[0])):
for j in range_constexpr(cute.size(acc_qk_mn, mode=[1])):
m_idx = tScS_mn[i, j][0]
n_idx = tScS_mn[i, j][1]
val = acc_qk_mn[i, j] * scale
if const_expr(causal):
q_abs = bidx_m * Br + m_idx
kv_abs = j_outer * Bc + n_idx
if kv_abs > q_abs:
val = -inf
sS[m_idx, n_idx] = val
cute.arch.sync_threads()
layout_acc_mn produces an mn-view of the per-thread fragment (mode 0 walks M-direction slots, mode 1 walks N-direction slots; details later in this post). For K4.1 the interpretation is: each thread sweeps (i, j) over its own fragment, asks tScS_mn for the actual tile coordinate, and writes the scaled value to sS at that coordinate.
After the scatter and the sync, sS contains the full \(S = QK^\top \cdot \text{scale}\) tile in row-major SMEM, identical in layout to K3’s sS. The K3 softmax and the K3 hand-rolled PV run against it unchanged. The price is a 16 KB SMEM round-trip per iteration (\(64 \times 64\) FP32). K4.2 removes that round-trip, which is what forces it to deal with the WGMMA C-layout directly.
TV layouts, properly (K4.2 needs these)
A TV layout is a CuTe layout whose two top-level modes are interpreted as \((T, V)\): Thread index and Value index within that thread. It’s a function
\[(T,\; V) \;\longrightarrow\; \text{tile\_position}\]that says, for every thread \(T\) and every value slot \(V\) in that thread’s fragment, which position in the underlying tile that pair refers to. Mechanically it’s still a CuTe layout (shape and stride); the convention is just that the first mode counts threads and the second counts values per thread.
Every tiled MMA exposes three:
tiled_mma.tv_layout_A
tiled_mma.tv_layout_B
tiled_mma.tv_layout_C
The crucial property of WGMMA C-layouts on SM90: both the Thread mode and the Value mode contribute to both M and N of the tile. Some sub-bits of the thread index walk along M; others walk along N. Same for the value index. This is why a single thread’s fragment is not a clean sub-rectangle of the tile.
Concrete numbers: SM90 BF16, 64×64 C tile
For our QK atom (M=N=64, K=16), the C TV layout shape resolves roughly to:
tv_layout_C.shape = ( (4, 32), ((2, 2), 8) )
^^^^^^^^ ^^^^^^^^^^^^
Threads (T) Values (V)
128 threads (\(4 \times 32\)) each holding 32 values (\(2 \times 2 \times 8\)). Per thread:
| Quantity | Value | Meaning |
|---|---|---|
| Total threads | 128 | one warpgroup |
| Total tile elements | 4096 | \(64 \times 64\) |
| Values per thread | 32 | \(4096 / 128\) |
| Distinct M rows / thread | 2 | mn-view mode 0 |
| Distinct N cols / thread | 16 | mn-view mode 1 |
| Threads sharing a row | 4 | for cross-thread row max / sum |
| Threads on different rows | 32 | \(128 / 4\) |
Two things from this table drive the K4.2 softmax:
- Each thread owns 2 distinct M rows, not one. Per-row state arrays (
s_max,a_sum,s_max_prev) must have length 2, not be a single scalar. - Each M row is shared by 4 threads. Row max and row sum reduce across those 4 sibling threads.
The stride-versus-M test
To peel apart “this sub-mode walks along M” from “this sub-mode walks along N” inside a hierarchical Thread or Value mode, CuTe uses a simple test. Picture the tile flattened M-major into a 1D buffer of \(M \cdot N\) entries. Addresses \(0 \ldots M-1\) form M-row 0, \(M \ldots 2M-1\) form M-row 1, and so on. Then:
- A stride \(< M\) moves within a single row → walks along M.
- A stride \(\geq M\) moves across rows → walks along N.
Two K4.2 utilities use this test: layout_acc_mn (peels the value mode) and reduction_target_n (peels the thread mode). Both call into a shared helper, layout_separate.
The K4.2 helpers
layout_separate
@staticmethod
def layout_separate(thr, src, ref):
lt = cute.make_layout(())
ge = cute.make_layout(())
for k, v in enumerate(ref):
if cutlass.const_expr(v < thr):
lt = cute.append(lt, src[k])
else:
ge = cute.append(ge, src[k])
# ... rank-1 packaging
Given a layout src and a parallel list of strides ref, it splits src’s sub-modes into two buckets: those whose corresponding stride in ref is \(<\) thr go into lt (M-walking part), the rest go into ge (N-walking part). In K4.2 thr is always tiled_mma.shape_mnk[0] = M.
layout_acc_mn
The per-thread C fragment naturally has shape (value_mode, M_tile_coord, N_tile_coord). We want an mn-view where mode 0 walks distinct M rows owned by this thread and mode 1 walks distinct N columns.
@cute.jit
def layout_acc_mn(self, tiled_mma, acc):
separated = self.layout_separate(
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
)
V_M = separated[0]
V_N = separated[1]
# compose with acc[1] (M-tile) and acc[2] (N-tile)
return cute.append(V_M_with_tile, V_N_with_tile)
The reference strides are tv_layout_C.stride[1] — the value mode’s strides (the [1] selects the value side of TV: shape = (T, V), so [0] is threads and [1] is values). After the split, the fragment is reorganised into (M-walking, N-walking) shape.
Indexing the result acc_mn[i, j]: \(i\) walks only along M-bits → distinct M rows of the tile; \(j\) walks only along N-bits → distinct N columns. cute.size(acc_mn, mode=[0]) gives the number of distinct M rows this thread owns: 2 for the 64×64 atom.
n_rows = cute.size(acc_pv_mn, mode=[0]) # 2
n_cols_pv = cute.size(acc_pv_mn, mode=[1]) # 16 for d=64, 32 for d=128
n_cols_qk = cute.size(acc_qk_mn, mode=[1]) # 16 for Bc=64
reduction_target_n
@cute.jit
def reduction_target_n(self, tiled_mma):
separated = self.layout_separate(
tiled_mma.shape_mnk[0],
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
tiled_mma.tv_layout_C.stride[0],
)
return separated[1] # the N-walking part of the thread mode
Same layout_separate machinery applied to the thread mode. Returns the part of the thread index whose stride is \(\geq M\) — the bits that walk along N. Threads that differ only in those bits are siblings on the same row.
For our 64×64 atom, this returns a 1-rank layout with shape [4]: 4 sibling threads per row. The K4.2 reduction iterates over the sub-modes and does a sub-warp shuffle whose group size is each sub-mode’s extent:
reduction_target_qk = self.reduction_target_n(qk_tiled_mma)
red_rank_qk = cute.rank(reduction_target_qk)
for r in range_constexpr(red_rank_qk):
for i in range_constexpr(n_rows):
s_max[i] = cute.arch.warp_reduction_max(
s_max[i], threads_in_group=reduction_target_qk.shape[r],
)
Deriving the group size from reduction_target_n rather than hardcoding 4 means the same softmax works for any atom shape CuTe gives us.
The K4.2 register softmax
Per-row state lives in register tensors of length n_rows:
s_max_layout = cute.make_layout(n_rows)
s_max = cute.make_rmem_tensor_like(s_max_layout, cutlass.Float32)
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
s_max_prev = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
for i in range_constexpr(n_rows):
s_max[i] = -inf
a_sum[i] = 0.0
acc_pv is allocated once before the mainloop, zero-initialised, and accumulates across all KV iterations:
acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape)
acc_pv_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(pv_tiled_mma, acc_pv.layout),
)
for i in range_constexpr(n_rows):
for j in range_constexpr(n_cols_pv):
acc_pv_mn[i, j] = 0.0
The zero-init makes the unified rescale benign on iter 0: with s_max_prev = -inf and s_max finite, scale_pv = exp(-inf) = 0, and 0 * 0 = 0, so the rescale silently does nothing on the first iteration.
Inside the mainloop, after QK has produced acc_qk:
acc_qk_mn = cute.make_tensor(
acc_qk.iterator, layout_acc_mn(qk_tiled_mma, acc_qk.layout),
)
# 1. Scale
for i in range_constexpr(n_rows):
for j in range_constexpr(n_cols_qk):
acc_qk_mn[i, j] *= scale
# 2. Causal mask
if const_expr(causal):
for i in range_constexpr(n_rows):
for j in range_constexpr(n_cols_qk):
q_abs = q_block_offset + tScS_mn[i, j][0]
kv_abs = kv_block_offset + tScS_mn[i, j][1]
if kv_abs > q_abs:
acc_qk_mn[i, j] = -inf
# 3. Save prev row max, new local row max, cross-thread reduce
for i in range_constexpr(n_rows):
s_max_prev[i] = s_max[i]
for j in range_constexpr(n_cols_qk):
s_max[i] = fmax(s_max[i], acc_qk_mn[i, j])
for r in range_constexpr(red_rank_qk):
for i in range_constexpr(n_rows):
s_max[i] = warp_reduction_max(
s_max[i], threads_in_group=reduction_target_qk.shape[r],
)
# 4. Per-row rescale of acc_pv and a_sum
for i in range_constexpr(n_rows):
scale_pv = exp(s_max_prev[i] - s_max[i])
a_sum[i] *= scale_pv
for j in range_constexpr(n_cols_pv):
acc_pv_mn[i, j] *= scale_pv
# 5. P = exp(S - s_max) in-place into acc_qk, accumulate local row sum
for i in range_constexpr(n_rows):
for j in range_constexpr(n_cols_qk):
p = exp(acc_qk_mn[i, j] - s_max[i])
acc_qk_mn[i, j] = p
a_sum[i] += p
Three things to note:
- Same FA2 recurrence as K2/K3, expressed against the per-thread fragment layout.
- Step 5 overwrites
acc_qkin place with \(P\). After this pointacc_qkis the (unnormalised) attention weights, not the score block. - The
a_sumcross-thread reduction is deferred to the finalize stage. Each thread accumulatesa_sum[i]over only the N columns it owns; the full row sum is computed once after all KV iterations.
The QK→PV register handoff: make_acc_into_op
acc_qk holds \(P\) in FP32 under the WGMMA C TV layout. PV wants its A operand in registers, in BF16, under the WGMMA A TV layout. The handoff has to do three things: cast FP32 → BF16, reinterpret the register layout C → A, and avoid moving bytes between threads.
@cute.jit
def make_acc_into_op(self, acc, operand_layout_tv, Element):
operand = cute.make_rmem_tensor_like(
self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
Element,
)
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
acc_vec = acc.load()
operand_as_acc.store(acc_vec.to(Element))
return operand
Step 1: allocate an A-layout register fragment in BF16.
operand = cute.make_rmem_tensor_like(
convert_c_layout_to_a_layout(acc.layout, ...),
Element, # BF16
)
convert_c_layout_to_a_layout reshapes the C-fragment layout into the A-fragment layout that physically aligns with how WGMMA expects A. This works because the SM90 WGMMA register layouts for C and A are designed to be reinterpretable into one another without inter-thread shuffles for BF16, precisely to make QK→PV through registers efficient.
@staticmethod
def convert_c_layout_to_a_layout(c, a):
return cute.make_layout(
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
stride=(c.stride[0], c.stride[1],
(c.stride[2], cute.size(a, mode=[2]) * c.stride[0][2])),
)
This is a CUTLASS C++ idiom ported into Python. The precise arithmetic isn’t load-bearing; the takeaway is that the output describes an A-shaped layout pointing at the same physical registers as the input C-shaped layout.
Step 2: build an aliasing view.
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
Same iterator (same registers as operand), C-shaped layout (same as acc). One register region, two views.
Step 3: cast and store via the C view.
acc_vec = acc.load() # read all of acc (FP32, C-layout)
operand_as_acc.store(acc_vec.to(Element)) # write back as BF16 via the C view
Reading is C-shaped (matches acc); writing routes the FP32→BF16 cast into the registers that will be read by WGMMA under the A-layout. The cast happens in registers, no SMEM, no shuffles.
FP8 caveat. The “no inter-thread shuffles” property holds for BF16. FP8 breaks it because the per-element bit width changes the per-thread register packing, and the cleanest BF16 reinterpretation no longer maps the right bytes to the right threads. K7 needs extra shuffles in
make_acc_into_op.
V’s mode-order trap
\(V\) stores the same physical bytes as \(K\): row-major \((N, d, B)\) with \(d\) contiguous. But \(V\) plays a different role in PV than \(K\) plays in QK, and this changes which axis is “the contraction axis.”
For QK, \(K\) is the B operand and the contraction is \(d\). \(K\)’s contiguous axis is \(d\), which is QK’s K axis. So \(K\) is K-major.
For PV, \(V\) is the B operand and the contraction is \(B_c\). \(V\)’s contiguous axis is still \(d\), but for PV that’s the N axis, not the K axis. So \(V\) needs to be MN-major for PV.
WGMMA needs the right major-mode label on its B operand, and the right matching SMEM layout. The cleanest way to get both ends to agree is to re-view \(V\) at kernel entry with mode order \((d, k, l)\) instead of \((k, d, l)\) — the bytes are unchanged, but CuTe now sees \(d\) as the leading axis:
mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, [1, 0, 2]))
After the re-view, LayoutEnum.from_tensor(mV) returns COL_MAJOR and sm90_mma_major_mode() returns MN. The SMEM helper builds an MN-major layout, the PV tiled MMA picks the right WGMMA variant, and the TMA atom loads \(V\) into the right SMEM shape with tile \((d, B_c)\). Same memory bytes throughout; only the labels change.
Building the PV tiled MMA
pv_tiled_mma = sm90_utils.make_trivial_tiled_mma(
INPUT_DTYPE, INPUT_DTYPE,
warpgroup.OperandMajorMode.K, # A = P, K-major
v_layout_enum.sm90_mma_major_mode(), # B = V, MN-major from the re-view
ACC_DTYPE,
(1, 1, 1),
(Br, d), # M, N of the tile (N is now d)
warpgroup.OperandSource.RMEM, # A comes from registers
)
Two things changed relative to the QK tiled MMA:
- M/N of the tile is
(Br, d)instead of(Br, Bc). PV result is \((B_r, d)\), matching \(O_i\). -
OperandSource.RMEMfor A. WGMMA reads A from registers; we’ll pass the BF16 fragment returned bymake_acc_into_opdirectly.
The PV SMEM layout for V:
pv_mma_tile_mnk = (Br, d, Bc)
sV_layout_staged = sm90_utils.make_smem_layout_b(
v_layout_enum, pv_mma_tile_mnk, INPUT_DTYPE, kv_stage,
)
The (N, K) of B for this matmul is (d, Bc), which matches our re-viewed \(V\)’s (d, k, l) shape. The TMA atom for V correspondingly takes tile shape (d, Bc).
The PV WGMMA
tOrP = self.make_acc_into_op(acc_qk, pv_tiled_mma.tv_layout_A, INPUT_DTYPE)
cute.nvgpu.warpgroup.fence()
tOrV_k = tOrV[(None, None, None, kv_consumer_state.index)]
num_k_blocks_pv = cute.size(tOrP, mode=[2])
for k_block_idx in range_constexpr(num_k_blocks_pv):
pv_tiled_mma.set(Field.ACCUMULATE, True)
cute.gemm(
pv_tiled_mma, acc_pv,
tOrP[(None, None, k_block_idx)],
tOrV_k[(None, None, k_block_idx)],
acc_pv,
)
cute.nvgpu.warpgroup.commit_group()
cute.nvgpu.warpgroup.wait_group(0)
Three differences from K4.1’s QK call:
ACCUMULATE is always True. acc_pv is the running attention output across all KV iterations. Every PV WGMMA adds to whatever acc_pv already holds. There’s no “first call writes the accumulator” case for PV.
A comes from registers. tOrP is the BF16 A-layout fragment built via make_acc_into_op. The k_block axis lives in tOrP’s mode 2: num_k_blocks_pv = cute.size(tOrP, mode=[2]). For \(B_c = 64\) and atom \(K = 16\), that’s 4 k_blocks.
The PV fence applies to the A operand. The QK fence was largely cosmetic; this one is functional, marking the registers in tOrP as ready for WGMMA after the softmax and make_acc_into_op cast.
Finalize
Two pieces.
Cross-thread a_sum reduce. Throughout the mainloop each thread accumulated a_sum[i] over only the N columns it owns. The actual row sum requires combining with the 3 sibling threads. Done once, at the end, using reduction_target_n on the PV tiled MMA:
reduction_target_pv = self.reduction_target_n(pv_tiled_mma)
red_rank_pv = cute.rank(reduction_target_pv)
for r in range_constexpr(red_rank_pv):
for i in range_constexpr(n_rows):
a_sum[i] = cute.arch.warp_reduction_sum(
a_sum[i], threads_in_group=reduction_target_pv.shape[r],
)
Output scatter with a coord tensor. acc_pv is in registers under the PV C TV layout. Same fix as K4.1’s acc_qk: probe with an identity tensor partitioned through partition_C, write to GMEM at the per-thread tile coordinates.
cO = cute.make_identity_tensor((Br, d))
tOcO = pv_thr_mma.partition_C(cO)
tOcO_mn = cute.make_tensor(
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout),
)
for i in range_constexpr(n_rows):
inv_l = cute.arch.rcp_approx(a_sum[i])
if a_sum[i] == 0.0 or a_sum[i] != a_sum[i]:
inv_l = 1.0
for j in range_constexpr(n_cols_pv):
m_idx = tOcO_mn[i, j][0]
n_idx = tOcO_mn[i, j][1]
gO_tile[m_idx, n_idx] = (acc_pv_mn[i, j] * inv_l).to(OUTPUT_DTYPE)
Per thread: for each of its 2 rows, compute \(1 / \ell_i\), then for each of its n_cols_pv owned N columns look up the actual \((M, N)\) tile coordinate from tOcO_mn and write the divided value. With 128 threads each writing \(n_\text{rows} \cdot n_\text{cols\_pv}\) slots, every position in the \(B_r \times d\) output tile is written exactly once.
New CuTe primitives in K4
WGMMA-specific:
-
sm90_utils.make_smem_layout_a/make_smem_layout_b— swizzled SMEM layouts. -
sm90_utils.make_trivial_tiled_mma— bundle a WGMMA atom with TV layouts. -
tiled_mma.get_slice(tidx)— thread-local view of the warpgroup tile. -
thr_mma.partition_A/partition_B/partition_C— per-thread views under the operand’s TV layout. -
tiled_mma.make_fragment_A/B/C— wrap a partitioned tensor in the WGMMA-expected layout. -
thr_mma.partition_shape_C— per-thread C fragment shape. -
cute.gemm(tiled_mma, D, A, B, C)— issue one WGMMA. -
tiled_mma.set(Field.ACCUMULATE, bool)— control add-or-overwrite. -
cute.nvgpu.warpgroup.fence/commit_group/wait_group(n)— the async issue protocol.
Layout and register utilities:
-
cute.make_identity_tensor((M, N))— coordinate-reporting tile. -
cute.make_rmem_tensor_like(layout, dtype)— register tensor with a given layout and dtype. -
cute.select(layout, indices)— re-view a tensor with a different mode order (used for the V transpose). -
cute.append,cute.make_layout(())— building blocks forlayout_separate. -
tensor.load()/tensor.store(vec)— bulk read/write of a register tensor as a single value vector. -
cute.rank(layout)— number of top-level modes. -
cute.size(tensor, mode=[k])— extent of mode \(k\).
Project-level helpers built on top of the above and reused by K5+:
-
layout_separate,layout_acc_mn,reduction_target_n,convert_c_layout_to_a_layout,make_acc_into_op.
For the next kernel in the worklog, see K5: Warp specialization.