K3 in CuTe [FA3]

This is the in-depth companion to the K3 section of the FA3 worklog. Refer back there for the high-level idea — TMA replaces K2’s strided loads, the producer/consumer pipeline sets up overlap K4 will lean on, and the substages 1 → 2 → 5 are structural preparation for K4. This post covers the mechanics: what TMA actually is, how the pipeline protocol works, what mbarriers do, and the CuTe API for putting it all together.

The kernel file is kernels/k3.py in the project repo.

TMA from first principles

The unit

TMA (Tensor Memory Accelerator) is SM90 hardware that copies tiles of a multidimensional GMEM tensor into SMEM, asynchronously, with hardware-level address calculation. Three properties distinguish it from the SM80 cp.async model:

  1. One thread issues the entire copy. A single thread executes one TMA instruction; hardware walks the source tile, generates addresses, fetches bytes, writes them to SMEM in the destination layout. Compare to cp.async where every thread issues its own per-element load.

  2. The tile is multidimensional and known to hardware. Before the kernel runs, the host builds a tensor descriptor recording the source tensor’s shape, strides, and dtype. The descriptor is an opaque object that the TMA unit consults at issue time. The kernel then references the descriptor and supplies tile coordinates; hardware computes addresses from shape × stride × coord.

  3. Async with completion tracked by an mbarrier. The TMA instruction returns immediately. When the copy completes, hardware arrives on a memory barrier with the byte count it transferred. Consumers wait on the barrier.

TMA supports G2S (load) and S2G (store); K3 uses G2S only. It supports multicast (one TMA broadcasts a tile to multiple CTAs in a cluster); K3 sets num_multicast=1.

Why this matters even when K2 wasn’t bandwidth-bound

Three reasons. (a) WGMMA in K4 requires operands in SMEM with specific (eventually swizzled) layouts; TMA is the natural mechanism to put them there. (b) Even at one stage with no overlap, TMA frees 255 of 256 threads from the load phases; K4 will fill them with WGMMA. (c) TMA can saturate HBM bandwidth from a single issuer more efficiently than 256 threads each issuing per-element loads.

The producer / consumer pattern

This is the general async-loading discipline; TMA is one specialisation of it, and warp specialization in K5 is another.

Without pipelining, the kernel issues a load, waits for completion, consumes the data, and only then issues the next load. The SM is idle whenever it waits.

With pipelining, the kernel keeps \(N\) loads in flight at all times by carving SMEM into a circular buffer of \(N\) stages. While the consumer is reading stage \(j\), the producer is filling stage \(j+1\) (and \(j+2\), …). \(N\) is the latency-hiding budget; it is bounded above by available SMEM and by the number of barriers the kernel manages.

The producer and consumer cycle through stages via a tight protocol. Each stage has two mbarriers:

  • a “full” barrier that the producer signals when the stage’s data has arrived, and the consumer waits on;
  • an “empty” barrier that the consumer signals when it’s done with the stage, and the producer waits on before reissuing into the same slot.

Producer side (per iteration):

producer_acquire(state)   # wait on empty barrier for stage state.index
issue load into stage state.index, with full barrier as completion target
producer_commit(state)    # TMA arrives on full barrier on completion;
                          # for TMA this is a no-op (HW signals directly)
state.advance()           # move to next stage

Consumer side (per iteration):

consumer_wait(state)      # wait on full barrier for stage state.index
use the data at stage state.index
consumer_release(state)   # signal empty barrier for stage state.index
state.advance()           # move to next stage

At steady state with \(N\) stages and a long loop, the producer runs roughly \(N-1\) iterations ahead of the consumer, with the buffer continuously filled at depth \(N\).

For K3.1, \(N = 1\): there is no offset. The producer issues, waits to be done (via the consumer release), and reissues. No overlap. The protocol still runs because K3.2 and K3.3 use the same code with \(N > 1\).

Pipeline state. Both sides carry a small state object with two fields:

  • index — current stage in the circular buffer; cycles through \(0, 1, \ldots, N-1, 0, 1, \ldots\)
  • count — total stages this state has advanced through; monotonic, does not wrap

index answers “which SMEM slot do I touch right now?”; count answers “which KV chunk does that correspond to in GMEM?” Both are needed: the slot determines where in SMEM bytes land; the count determines which tile of GMEM is loaded.

mbarriers

An mbarrier is not a single bit. It’s a small SMEM object with three fields:

  • A byte counter that arrivers increment by the number of bytes transferred.
  • An expected total (tx_count) set at init.
  • A phase bit that flips each time the counter reaches the expected total and resets.

Lifecycle of one “fire”:

  1. Init: tx_count is set; phase is in its initial state.
  2. TMA issues; hardware copies; on completion, hardware arrives on the barrier with its transferred byte count.
  3. When cumulative bytes ≥ tx_count, the phase flips. This is the “fire” event.
  4. consumer_wait blocks until it observes the phase flip. After it returns, the data is fully visible in SMEM.
  5. Consumer eventually arrives on the empty barrier; producer can refill.

Two non-obvious consequences:

Why bytes, not arrivals. The traditional __syncthreads model counts arrivals: \(N\) threads arrived, barrier fires. TMA doesn’t fit that mould because one thread issues a copy that produces many bytes. So the barrier counts what the hardware actually delivers: bytes. This is the “transaction” in transaction barrier. It’s also why TMA needs a different barrier mechanism than __syncthreads(): that primitive is a per-CTA arrive-and-wait on a thread-count, and can’t observe “are the bytes there yet?” because no thread is doing the copy at the software level.

Aggregation across multiple TMAs. If tx_count equals the sum of bytes from several TMAs targeting the same barrier, the barrier only fires when all of them have completed. K3 uses this to tie \(K_j\) and \(V_j\) to the same stage: one barrier, tx_count = sK_bytes + sV_bytes, two TMA copies, both target the same barrier. The barrier fires only when both have arrived. If tx_count were set to only sK_bytes, the barrier would fire after \(K_j\) alone, and the consumer would race against \(V_j\)’s in-flight write.

Why K and V get separate SMEM regions

K2 used one SMEM buffer for K and V, reused inside an iteration: load K, build S, exponentiate to P (which displaces S in sS), overwrite the same buffer with V, do PV. That trick fundamentally relies on the load being synchronous within the iteration: K is finished before V starts, so they can share storage.

In K3, the pipeline can have the producer loading stage \(j+1\)’s K while the consumer is still reading stage \(j\)’s V. The two operations may overlap in time, so they cannot share the same SMEM bytes.

So K3 allocates sK and sV as separate MemRanges in the storage struct, each staged \(N\) deep:

@cute.struct
class SharedStorage:
    sK: cute.struct.Align[
        cute.struct.MemRange[INPUT_DTYPE, cute.cosize(sK_layout_staged)], 1024,
    ]
    sV: cute.struct.Align[
        cute.struct.MemRange[INPUT_DTYPE, cute.cosize(sV_layout_staged)], 1024,
    ]
    ...

They share a stage index and a single barrier (via tx_count = sK_bytes + sV_bytes), but the bytes live in different regions. The total KV-stage cost is therefore \(2 B_c d\) bytes (K + V), not \(B_c d\).

Swizzle: deferred to K4

A naive row-major SMEM layout (stride=(d, 1)) suffers from bank conflicts when many threads of a warp simultaneously read the same column of a row. SMEM has 32 banks of 4 bytes each; addresses that hash to the same bank serialise.

Swizzling is a deterministic permutation of SMEM addresses that scatters the bytes of a logically-contiguous row across banks, so that hardware-fixed access patterns (the ones tensor cores emit) hit distinct banks. From the user’s perspective, you still index a tile as sQ[i, k]; under the hood the layout maps (i, k) to a permuted offset.

K3 keeps SMEM layouts plain row-major (cute.make_layout(..., stride=(d, 1, ...)) for the staged buffers). The reason: swizzle is only meaningful for the tensor-core access pattern, which K3’s scalar FP32 GEMM doesn’t emit. Swizzling here would just permute addresses without reducing conflicts in any pattern we actually issue. K4 brings swizzled SMEM layouts together with WGMMA, where the access pattern is fixed by hardware and swizzle becomes load-bearing.

The CuTe TMA + pipeline API surface

Building a TMA atom

tma_atom_q, tma_tensor_q = cute.nvgpu.cpasync.make_tiled_tma_atom(
    cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(),
    mQ, sQ_layout, (Br, d), num_multicast=1,
)

CopyBulkTensorTileG2SOp selects the variant: G2S (load), single CTA, tile-mode. mQ is the GMEM source; the descriptor records its shape, strides, dtype. sQ_layout is the SMEM destination layout (per stage). (Br, d) is the tile shape per issue.

Returns two things. tma_atom_q is the opaque handle that holds the descriptor; it goes as the first argument to cute.copy(...). tma_tensor_q is a re-wrapped view of mQ with TMA-compatible layout metadata; you must use this from here on. cute.copy from raw mQ will not work.

Prefetching descriptors

if warp_idx == 0:
    cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q)

Pulls the descriptor into SM caches so the first TMA issue doesn’t pay descriptor-fetch latency. Once per atom, on one warp.

Partitioning

tQsQ, tQgQ = cute.nvgpu.cpasync.tma_partition(
    tma_atom_q, 0, cute.make_layout(1),
    cute.group_modes(sQ_full, 0, 2),
    cute.group_modes(gQ_tiles, 0, 2),
)

0 and cute.make_layout(1) are the multicast slot index and CTA layout: “no multicast, this CTA only.” cute.group_modes(t, 0, 2) collapses the first two modes of t into one, because TMA treats the whole tile as a single opaque blob. After grouping, sQ_full has shape (tile_blob, n_stages) and gQ_tiles has shape (tile_blob, tile_coords...). The returned tQsQ and tQgQ are partitioned views ready for cute.copy.

Issuing a copy

cute.copy(
    tma_atom_q,
    tQgQ[(None, bidx_m)],                   # GMEM tile coord
    tQsQ[(None, q_producer_state.index)],   # SMEM stage
    tma_bar_ptr=q_pipeline.producer_get_barrier(q_producer_state),
)

tma_bar_ptr tells TMA which mbarrier to arrive on at completion.

The pipeline object

q_pipeline = pipeline.PipelineTmaAsync.create(
    barrier_storage=storage.q_pipeline_array_ptr.data_ptr(),
    num_stages=q_stage,
    producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
    consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps),
    tx_count=q_bytes,
    cta_layout_vmnk=cute.make_layout((1, 1, 1, 1)),
    defer_sync=True,
)

Field by field:

  • barrier_storage — pointer to the SMEM region that will hold the stage’s mbarriers (two per stage).
  • num_stages — the \(N\) of the circular buffer.
  • producer_group=CooperativeGroup(Agent.Thread)exactly one thread issues TMAs. Matches the hardware: a TMA instruction is issued by a single thread.
  • consumer_group=CooperativeGroup(Agent.Thread, num_warps) — all num_warps × 32 threads are consumers. They all execute the compute body and need to see the loaded data.
  • tx_count — bytes-per-stage; the full barrier fires when this many bytes have arrived. For Q it’s Br * d * sizeof(bf16). For KV it’s sK_bytes + sV_bytes so the barrier waits for both TMAs.
  • cta_layout_vmnk=cute.make_layout((1, 1, 1, 1)) — single CTA, no multicast.
  • defer_sync=True — skips an internal cross-CTA init handshake because we drive pipeline_init_arrive / pipeline_init_wait manually.

Pipeline state objects

q_producer_state = pipeline.make_pipeline_state(PipelineUserType.Producer, num_stages)
q_consumer_state = pipeline.make_pipeline_state(PipelineUserType.Consumer, num_stages)

Two opaque counters per pipeline, one per side. Both start at (index=0, count=0); both advance() after every iteration.

K3.1 walkthrough

The structure: build TMA atoms and pipelines, prefetch descriptors, load Q once, run the KV mainloop with the producer/consumer protocol. The compute body is K2’s, unchanged.

Q load: one shot

if warp_idx == 0:
    q_pipeline.producer_acquire(q_producer_state)
    cute.copy(
        tma_atom_q,
        tQgQ[(None, bidx_m)],
        tQsQ[(None, q_producer_state.index)],
        tma_bar_ptr=q_pipeline.producer_get_barrier(q_producer_state),
    )
    q_pipeline.producer_commit(q_producer_state)
    q_producer_state.advance()

q_pipeline.consumer_wait(q_consumer_state)

Warp 0 acquires the (sole) Q stage, issues one TMA for the M-tile this CTA owns (bidx_m), commits, and advances. All 256 threads then consumer_wait on the full barrier. After this returns, \(Q_i\) is in SMEM and never reloaded for the rest of the kernel.

producer_commit is a no-op for TMA: the hardware itself signals the full barrier on completion. The call is there for API symmetry with cp.async producers.

KV mainloop: two TMAs, one barrier

for j in cutlass.range(n_kv, unroll=1):
    if warp_idx == 0:
        kv_pipeline.producer_acquire(kv_producer_state)
        cute.copy(tma_atom_k,
                  tKgK[(None, j)],
                  tKsK[(None, kv_producer_state.index)],
                  tma_bar_ptr=kv_pipeline.producer_get_barrier(kv_producer_state))
        cute.copy(tma_atom_v,
                  tVgV[(None, j)],
                  tVsV[(None, kv_producer_state.index)],
                  tma_bar_ptr=kv_pipeline.producer_get_barrier(kv_producer_state))
        kv_pipeline.producer_commit(kv_producer_state)
        kv_producer_state.advance()

    kv_pipeline.consumer_wait(kv_consumer_state)
    sK = sK_full[(None, None, kv_consumer_state.index)]
    sV = sV_full[(None, None, kv_consumer_state.index)]

    # K2 compute body (S = QK^T, online softmax, O += PV)

    cute.arch.sync_threads()
    kv_pipeline.consumer_release(kv_consumer_state)
    kv_consumer_state.advance()

Step by step:

  1. Producer acquire. One thread (warp 0’s first) waits on the KV stage’s empty barrier.
  2. Two TMA copies, one barrier. Both cute.copy calls target the same tma_bar_ptr. The barrier’s tx_count = sK_bytes + sV_bytes ensures it fires only once both copies have arrived.
  3. Commit + advance. Bookkeeping for the producer state. Actual completion comes from hardware on the full barrier.
  4. Consumer wait. All 256 threads block on the full barrier. After it returns, \(K_j\) is in sK and \(V_j\) is in sV.
  5. Slice + compute. The kernel slices each stage out (sK_full[..., kv_consumer_state.index]) and runs the K2 compute body unchanged.
  6. Consumer release. All threads arrive on the empty barrier. Producer is now free to reissue.
  7. Advance. Consumer state moves to next stage (wraps to 0 at \(N=1\)).

With \(N = 1\), steps 1-3 of iteration \(j+1\) can only begin after step 6 of iteration \(j\). No overlap.

K3.2: prefetch + steady-state

With \(N > 1\) the protocol splits into two phases.

Prefetch. Before the mainloop, the producer fills the first \(\min(N, n_{kv})\) stages:

prefetch_count = cutlass.min(kv_stage, n_kv)
if warp_idx == 0:
    for _ in cutlass.range(prefetch_count, unroll=1):
        kv_pipeline.producer_acquire(kv_producer_state)
        cute.copy(tma_atom_k, tKgK[(None, kv_producer_state.count)],
                  tKsK[(None, kv_producer_state.index)], tma_bar_ptr=...)
        cute.copy(tma_atom_v, tVgV[(None, kv_producer_state.count)],
                  tVsV[(None, kv_producer_state.index)], tma_bar_ptr=...)
        kv_pipeline.producer_commit(kv_producer_state)
        kv_producer_state.advance()

Note kv_producer_state.count for the GMEM tile index and .index for the SMEM slot. Same state object, two different fields, two different roles.

Steady-state. Each mainloop iteration consumes one stage and the producer reissues into the slot the consumer just released:

for j in cutlass.range(n_kv, unroll=1):
    kv_pipeline.consumer_wait(kv_consumer_state)
    sK = sK_full[(None, None, kv_consumer_state.index)]
    sV = sV_full[(None, None, kv_consumer_state.index)]

    # K2 compute body

    cute.arch.sync_threads()
    kv_pipeline.consumer_release(kv_consumer_state)
    kv_consumer_state.advance()

    if warp_idx == 0 and kv_producer_state.count < n_kv:
        kv_pipeline.producer_acquire(kv_producer_state)
        cute.copy(tma_atom_k, tKgK[(None, kv_producer_state.count)],
                  tKsK[(None, kv_producer_state.index)], tma_bar_ptr=...)
        cute.copy(tma_atom_v, tVgV[(None, kv_producer_state.count)],
                  tVsV[(None, kv_producer_state.index)], tma_bar_ptr=...)
        kv_pipeline.producer_commit(kv_producer_state)
        kv_producer_state.advance()

The kv_producer_state.count < n_kv guard stops the producer from issuing past the end. The last \(N - 1\) iterations are drain: the producer is idle, the consumer finishes the pre-loaded stages.

This is the same code K3.3 uses with kv_stage = 5.

K3.3: why 5 stages

Two bounds determine the depth.

Latency-hiding bound. With \(N\) stages the producer can run \(N - 1\) iterations ahead. TMA round-trip latency on H100 is on the order of hundreds of cycles per tile; one or two stages is already enough to hide the load against K3’s scalar-GEMM consumer. For K3 specifically, 5 stages buys no additional hiding; the consumer is so much slower than the load that even one stage is overkill.

SMEM bound. The total per-CTA SMEM as a function of kv_stage:

\[\text{SMEM} = \text{sQ}_{1\text{ stage}} + \text{kv\_stage} \cdot (\text{sK}_{1\text{ stage}} + \text{sV}_{1\text{ stage}}) + \text{sS}\]

For \(B_r = B_c = 64\), \(d = 128\), BF16 inputs, FP32 sS:

\[\begin{aligned} \text{sQ}_{1\text{ stage}} &= 64 \times 128 \times 2 = 16\text{ KB} \\ \text{sK}_{1\text{ stage}} &= 64 \times 128 \times 2 = 16\text{ KB} \\ \text{sV}_{1\text{ stage}} &= 64 \times 128 \times 2 = 16\text{ KB} \\ \text{sS} &= 64 \times 64 \times 4 = 16\text{ KB} \end{aligned}\]

Per-CTA totals across substages:

Substage kv_stage sK + sV total Grand total
K3.1 1 32 KB 64 KB
K3.2 2 64 KB 96 KB
K3.3 5 160 KB 192 KB

K3.3 at 192 KB fits in H100’s 228 KB/SM with comfortable margin, but only one CTA per SM (two CTAs would need ≤ 114 KB each). Beyond 5 stages the SMEM bound starts cutting into other things; below 5 K4 loses overlap headroom for the much faster WGMMA consumer.

Why we settle on 5 for K3 even though K3 doesn’t benefit. K3.3 is a structural rehearsal. The pipeline depth K4 uses is 5; running it at depth 5 in K3 verifies the protocol works at that depth with the simpler consumer, before WGMMA enters the picture. K3 is plumbing; K4 is where the plumbing earns its cost.

For the next kernel in the worklog, see K4: WGMMA.