CuTe DSL fundamentals and primitives [FA3]
What CuTe is: CUTLASS, templates, and an embedded DSL
CuTe stands for “CUDA Tensor” and is the layout/tensor algebra at the heart of CUTLASS 3.x (the version we’re using). CUTLASS is NVIDIA’s open-source library of high-performance GEMM and convolution primitives. Before getting to what CuTe gives us, it helps to understand what CUTLASS is built on, and why.
CUTLASS is a C++ template library. Templates here are doing one specific job: they let tile shapes, layouts, dtypes, and instruction variants be compile-time parameters, so the compiler can specialize a single generic kernel across many concrete shapes without runtime cost. A handwritten GEMM kernel hardcodes its tile sizes; a templated one accepts them as types and lets the compiler emit one specialized SASS per instantiation. The same logic, recompiled for \(bM = 128, bN = 128, bK = 64\) on FP16 versus \(bM = 64, bN = 256, bK = 32\) on FP8, produces two completely different machine-code kernels, each as efficient as a hand-tuned version, with no runtime branching on shape. That’s the win templates buy: zero-cost generic programming.
The cost: CUTLASS C++ code is hard to read. Layouts, strides, swizzles, and instruction variants all live in template parameters, often deeply nested. A simple “load this tile from SMEM and multiply it” can become a several-line template instantiation. This works fine when you’re writing kernels professionally, but it’s a steep wall for everyone else.
That’s where the DSL comes in. A DSL (Domain-Specific Language) is a small language designed for one specific domain, with primitives that match the domain’s vocabulary directly. SQL is a DSL for relational queries; regex is a DSL for string patterns. The advantage over a general-purpose language is that DSL programs are short and read like the domain they describe, instead of being buried under host-language plumbing.
A DSL can be standalone (its own parser, compiler, runtime) or embedded in a host language, where the DSL’s “syntax” is just a particular use of the host language’s existing constructs. Embedded DSLs are common because they ride on the host’s tooling (editor support, packaging, debugging) for free.
CuTe Python DSL is an embedded DSL inside Python. The DSL lives as a Python library: when you write cute.make_layout(shape=(4, 3), stride=(1, 4)) or decorate a function with @cute.jit, you’re not in a separate language. You’re using Python objects whose meaning is defined by CuTe. The DSL provides a single, uniform vocabulary for talking about how data is laid out in any memory space (GMEM, SMEM, registers) and how it gets partitioned across threads and warps for Tensor-Core instructions. Raw CUDA forces you to track all of this with bare pointers and hand-written indexing; CuTe makes the Layout a first-class object that the compiler can reason about.
How does Python code end up running on a GPU? Two compilation models are worth distinguishing first.
AOT (Ahead-Of-Time) compilation is what happens with a normal C++ or CUDA C++ program: you run a compiler once, before the program starts, and it produces a finished binary. The compiler sees the source code, optimizes it, emits machine code, and that’s it. When the program runs later, there’s no compilation step in sight, you’re just executing pre-built instructions. Standard CUDA kernels work this way: nvcc compiles your .cu files into a binary that contains PTX and SASS, ready to launch.
JIT (Just-In-Time) compilation is what happens with languages like Python, Java, or JavaScript: the program starts running first, and a compiler kicks in during execution to turn parts of the program into machine code on demand. The compilation step is real, it does the same kind of work as AOT (parsing, optimizing, emitting machine code), but it runs inside the live process rather than offline.
CuTe Python uses JIT compilation for the GPU side. The workflow looks like this: you write a function and decorate it with @cute.jit. At that point, nothing has compiled yet, it’s just a Python function holding source code. When you call cute.compile(my_kernel, ...), the Python process introspects the function, builds an internal representation (IR), and lowers it through CuTe’s internal compilation pipeline down to PTX and finally SASS, the GPU’s machine code. After this compile step finishes, the kernel is a regular compiled GPU kernel that you launch and run on the device. So the GPU never executes Python. It executes the same kind of compiled binary you’d get from nvcc, the difference is when the compilation happens (during your Python program’s run, not before it).
Why JIT and not AOT? Because the kernel’s specialization depends on values the user supplies at run time: which dtype, what tile sizes, what head dimension, what swizzle. An AOT model would force the kernel author to either pre-compile every combination (combinatorially huge) or pay runtime branching costs to handle them generically. JIT sidesteps both: the user supplies the concrete configuration, CuTe compiles exactly the specialized kernel needed for that configuration, and from then on every launch with the same configuration reuses the compiled kernel. You get the specialization benefits of CUTLASS C++ templates without having to enumerate every instantiation up front.
We’ll use “JIT-time” and “compile-time” interchangeably from here on. They refer to the same moment in the lifecycle: the window during the Python program’s execution when CuTe is lowering your @cute.jit function to PTX/SASS, before any kernel launches on the GPU. Calling it “compile-time” emphasizes “the compiler has full visibility and can optimize”; calling it “JIT-time” emphasizes “this is happening inside the running program, not in a separate offline step.” Both names point to the same thing.
This brings us to a constraint that will show up everywhere in CuTe code: many values must be JIT-time constants, meaning the compiler must be able to evaluate them at JIT time, before any kernel launches. Concretely, a JIT-time constant is a value whose magnitude is known by the time cute.compile runs, so it can be folded directly into the emitted PTX as an immediate, used to size loops the compiler will unroll, and used to validate that shapes are compatible. A runtime value, in contrast, is one that only exists once the kernel is actually executing on the GPU: an entry of a tensor, a thread index, the result of a load. The compiler doesn’t know what these values are when it’s emitting code, only that “some 32-bit float will be here when the kernel runs.”
The reason this distinction matters: CuTe’s whole performance story depends on layouts being JIT-time objects. Tile sizes, strides, swizzle patterns, and partitioning must all be known to the compiler so it can fold offset arithmetic into constants, unroll inner loops, and emit specialized instructions. If a layout’s shape were a runtime value, none of this would be possible, the compiler would have to emit generic indexing code and lose the specialization advantage. So layouts (and the values used to build them, like tile sizes, stride numbers, and stage counts) are always JIT-time. Tensor data (the actual numbers being computed on) is always runtime.
CuTe gives you two explicit markers for JIT-time values:
-
cutlass.const_expr(x)declares thatxis a JIT-time constant. You’ll see this used in conditionals (e.g.,if cutlass.const_expr(stride < M):) to tell the compiler that the branch can be decided at JIT time and the dead path discarded entirely. -
cutlass.range_constexpr(N)is the JIT-time equivalent of Python’srange(N). It signals thatNis JIT-time-known and the loop should be fully unrolled at compile time. You’ll see this used for inner loops over fixed shapes (e.g.,for j in cutlass.range_constexpr(n_cols):), where the compiler emits a flat sequence of instructions instead of a runtime loop.
The practical consequence: if you accidentally pass a runtime value where a JIT-time constant is expected, CuTe fails. The failure usually happens inside the IR lowering pass, and the Python traceback doesn’t always point cleanly at the offending line. We’ll see specific instances of this as we write kernels.
A minimal example. A taste of what a CuTe layout looks like in the DSL:
# A 4-by-3 layout. The shape says "4 along the first axis, 3 along the second."
# The strides say "step 1 along the first axis, step 4 along the second."
# Together this is a function: coord (i, j) -> integer offset i*1 + j*4.
layout = cute.make_layout(shape=(4, 3), stride=(1, 4))
# Indexing the layout function at coord (2, 1) gives offset 2*1 + 1*4 = 6.
offset = layout((2, 1)) # -> 6
That’s a layout: a shape, a stride, and the function from coordinate to offset they define. No data is involved yet, the layout doesn’t know about any pointer. It just describes how to walk an indexed grid. The next subsection unpacks layouts properly, and we’ll see why this separation between “shape + stride” and “actual data” is what makes the whole abstraction work.
Layouts
A layout is the most important object in CuTe. Almost everything else is built on top of it. So before introducing tensors, partitioning, or any of the higher-level machinery, we need to understand what a layout actually is and how it works.
Shape and stride: a function from coordinate to offset
A CuTe layout is a pair: a shape and a stride. Together they define a pure function that takes a coordinate (a tuple of integers) and returns an integer offset. That’s it. A layout by itself is just this function. It does not point to memory, hold any data, or know about pointers. It only describes how to walk an indexed grid.
Concretely, given a layout with shape \((s_0, s_1, \ldots, s_{n-1})\) and stride \((d_0, d_1, \ldots, d_{n-1})\), the layout function evaluates a coordinate \((i_0, i_1, \ldots, i_{n-1})\) as:
\[\text{layout}(i_0, i_1, \ldots, i_{n-1}) = i_0 \cdot d_0 + i_1 \cdot d_1 + \cdots + i_{n-1} \cdot d_{n-1}\]The shape tells you the bounds of each axis: how many values that axis can take. The stride tells you the step along each axis: how far you move in the linear offset when you increment one coordinate by 1.
Here is a 4-by-3 layout from the previous subsection, drawn out:
layout = cute.make_layout(shape=(4, 3), stride=(1, 4))
# coord (i, j) maps to offset i*1 + j*4
# Visualizing the offsets at each (i, j):
#
# j=0 j=1 j=2
# i=0: 0 4 8
# i=1: 1 5 9
# i=2: 2 6 10
# i=3: 3 7 11
#
# coord (2, 1) -> offset 2*1 + 1*4 = 6
The first axis has 4 entries and stride 1, so moving along it walks through consecutive offsets 0, 1, 2, 3. The second axis has 3 entries and stride 4, so moving along it jumps by 4 each time. The layout function is just doing the dot product.
This is exactly the definition you’d use for a regular row-major or column-major 2D array, just expressed as shape and stride explicitly. If you want a 4x3 column-major matrix (rows contiguous), the layout is shape=(4, 3), stride=(1, 4). If you want it row-major (columns contiguous), it’s shape=(4, 3), stride=(3, 1). The shape stays the same, the stride changes, and CuTe handles both as the same kind of object.
Layouts describe walks, not storage
The crucial property: a layout describes a walk, not a piece of memory. Two completely different memory regions can share the same layout (the same indexing pattern), and the same memory region can be described by multiple layouts (different indexing patterns over the same bytes). We’ll see both of these in the next subsection on tensors and aliasing.
This separation is what lets CuTe do its job. Once a layout is in the compiler’s hands, it can analyze the indexing pattern, compose it with other layouts (more on this below), specialize tile arithmetic, and validate that two layouts are compatible, all without touching any data. The data shows up later, when we wrap a layout with a pointer to create a tensor.
Layouts are JIT-time objects
Layouts are almost always JIT-time values, not runtime values. When cute.compile runs, it knows every layout’s shape and stride concretely. That’s what lets the compiler fold offset arithmetic into immediates, unroll loops over fixed shapes, and statically check that the SMEM you’re partitioning has the structure the Tensor Core expects. This is the same constraint we discussed in the JIT subsection: the layout itself is metadata the compiler reasons about, while the data the layout points to is the runtime tensor.
If you tried to construct a layout from a runtime value (say, an entry of a tensor), CuTe would fail. The shape and stride must be JIT-time-known. This is why you’ll see layout construction always using either literal integers or cutlass.const_expr-wrapped expressions.
Hierarchical (nested) shapes
So far, layouts have looked like simple flat shape and stride tuples. The move that makes CuTe powerful is allowing the shape (and stride) to be nested:
layout = cute.make_layout(shape=((2, 2), 3), stride=((1, 4), 8))
# coord ((i0, i1), j) -> offset i0*1 + i1*4 + j*8
The outer rank is 2 (there are two top-level modes: the nested pair (2, 2) and the scalar 3). The first top-level mode is itself a tuple, with its own shape (2, 2) and stride (1, 4). A coordinate along this nested mode is itself a pair: (i0, i1). The layout function then sums all stride contributions, from both the inner and outer modes.
We need to be careful about a piece of CuTe vocabulary here. A mode is one top-level axis of a layout, and a mode can be either flat (a single integer shape and stride) or nested (a tuple of sub-shapes, each with its own stride). The terminology is: a flat mode has one shape and one stride. A nested mode has a tuple of sub-shapes, each with a corresponding sub-stride, and the offset contributed by that mode is the sum over all sub-contributions.
This is the key bit: a nested mode is one logical axis of the layout, but its internals fan out into multiple stride contributions. A single logical coordinate along that axis decomposes into a tuple, and the offset for that mode comes from summing strides across the tuple.
Why does this matter? Because real hardware geometries are hierarchical. Consider an axis of size 64 described as a nested mode with shape \((8, 2, 4)\) and strides \((1, 16, 32)\). A logical coordinate \(m \in [0, 64)\) along this axis decomposes into a tuple \((m_0, m_1, m_2)\) where \(m_0 \in [0, 8)\), \(m_1 \in [0, 2)\), \(m_2 \in [0, 4)\). The decomposition uses a positional system, similar to place-value in decimal but with the shape sizes as bases: \(m_0\) is the innermost coordinate, \(m_1\) steps after \(m_0\) has covered all 8 of its values, and \(m_2\) steps after \(m_1\) has covered both of its values. So the formula is:
\[m = m_0 + 8\, m_1 + (8 \cdot 2)\, m_2 = m_0 + 8\, m_1 + 16\, m_2\]The multipliers \(1, 8, 16\) are the running products of the inner shapes (1, then 8, then \(8 \cdot 2 = 16\)). This is how a flat index maps into a hierarchical coordinate.
The memory offset is a separate calculation, using the strides instead of these positional multipliers:
\[\text{offset}(m) = m_0 \cdot 1 + m_1 \cdot 16 + m_2 \cdot 32\]So one logical \(m\) coordinate pulls contributions from three different strides. This is exactly how WGMMA describes its M axis: the hierarchy is the hardware geometry. The question “is the M dimension a single contiguous stride?” no longer has a yes-or-no answer. It has a hierarchical answer, where some bits of \(m\) walk one stride and other bits walk a different stride.
If this feels abstract right now, that’s fine. The point to internalize is that nested modes let one logical axis encode multiple stride patterns at once, and that’s what makes layouts powerful enough to describe Tensor-Core fragments, swizzled SMEM, and thread/value partitioning, all with the same primitive. We’ll see this used concretely when we get to TV layouts and WGMMA fragments later in this section.
Operations on layouts
CuTe provides a small set of layout operations that you’ll see throughout FA3 kernel code. The full list is larger; these are the ones we actually use.
cute.size(layout) returns the total number of elements the layout addresses. For a flat layout it’s just the product of the shape; for a hierarchical layout it’s the product of all leaf shape values.
layout = cute.make_layout(shape=((2, 2), 3), stride=((1, 4), 8))
cute.size(layout) # -> 2 * 2 * 3 = 12
cute.size(layout, mode=[0]) # -> 2 * 2 = 4 (size of mode 0, including its nested structure)
cute.size(layout, mode=[1]) # -> 3
The mode= argument is important. It takes a list (because modes can themselves be nested) and returns the size of that mode, handling hierarchy correctly.
cute.rank(layout) returns the number of top-level modes. For the layout above, cute.rank(layout) == 2. Rank counts top-level modes only; it doesn’t peer inside nested modes.
cute.slice_(layout, coord) takes a coordinate where some entries are integers (which get fixed) and some are None (which stay free), and returns a new layout over only the free axes. This is how you take a slice through a higher-dimensional layout to look at a sub-region. The fixed-coordinate values contribute their stride contributions to a constant offset baked into the new layout; the free axes keep their original shape and stride.
# 3D layout, shape (4, 3, 2), strides (1, 4, 12).
layout = cute.make_layout(shape=(4, 3, 2), stride=(1, 4, 12))
# Fix the third coord at 0, leave the first two free.
sliced = cute.slice_(layout, (None, None, 0))
# `sliced` is now a 2D layout with shape (4, 3) and stride (1, 4).
# It represents: "the layout you get if you walk the first two axes of
# the original 3D grid, with the third axis pinned to index 0."
# Indexing sliced[(i, j)] is equivalent to indexing the original at (i, j, 0).
Slicing doesn’t reduce the layout to a different function. It just fixes some inputs of that function and exposes a lower-rank view over the rest. In a kernel, you use this constantly: a per-CTA tile is a slice of the full GMEM tensor’s layout (with the CTA’s tile-coords fixed), a per-stage SMEM view is a slice with the pipeline-stage index fixed, and so on.
cute.append(a, b) extends a layout by adding b as a new outer mode. It’s the inverse direction of slicing: you build up a higher-rank layout by appending modes. We’ll see this used in the TV-layout subsection when we surgically reassemble layouts from extracted pieces.
cute.flat_divide(tensor, tile_shape) is the workhorse for partitioning a tensor into tiles. Given a tensor and a tile shape, it returns a new tensor whose layout has the tile interior at the front and the tile coordinates at the back.
# Suppose mQ has logical shape (seqlen, head_dim) = (1024, 128).
# Tile it into (Br, d) = (64, 128) tiles.
tiled = cute.flat_divide(mQ, (Br, d))
# tiled has shape (Br, d, n_M_tiles)
# ^^^^^^^^ ^^^^^^^^^
# one tile how many tiles
# Here n_M_tiles = 1024 / 64 = 16.
You then index tiled[(None, None, m)] to get the \(m\)-th tile, which still has the same Br-by-d shape. This is how every kernel in this blog gets its per-CTA tile out of the full GMEM tensor.
These four primitives (size, rank, slice_, append, plus flat_divide) are most of what you need to navigate layouts in CuTe code. Specific kernels will use a few more as they come up.
Layouts never touch data. This is worth stating explicitly before we move on. Every operation in this subsection (size, rank, slice_, append, flat_divide) operates only on the layout: it rearranges, slices, or extends the coordinate-to-offset function. None of these operations read from memory, copy bytes, or modify the underlying data buffer. They’re pure compile-time transformations of an indexing pattern.
The actual data shows up only when we wrap a layout with a pointer to create a tensor (the next subsection). At that point, indexing the tensor at coordinate \(c\) does two things: (1) compute the offset via the layout function \(\text{layout}(c)\), and (2) read or write the byte at that offset. Step 1 is a pure arithmetic operation the compiler can heavily optimize. Step 2 is the only step that ever touches memory.
So when a kernel calls cute.flat_divide(mQ, (Br, d)), no Q data moves. The result is a re-layout of mQ: same underlying bytes, new coordinate function for walking them. Same when a kernel slices a layout to pick out one stage of a circular buffer, or appends a mode to extend a fragment shape. Layouts are the planning step; the work itself happens later, when data flows through cute.copy (memory motion) or cute.gemm (Tensor-Core compute).
With that in hand, we can introduce the object that pairs a layout with actual data: the tensor.
Tensors
A layout describes a walk over an indexed grid, but no data. A tensor is the object that pairs a layout with actual data, by attaching a pointer. Once you have both, you can index, read, and write.
A tensor is an (iterator, layout) pair
In CuTe, a tensor is:
\[\text{tensor} = (\text{iterator},\; \text{layout})\]The iterator is a pointer-like handle that holds a base address along with type information (the dtype of the elements it points to). It is not a Python iterator in the for x in it sense. Conceptually, treat it as a typed C++ pointer: it knows the address it points to and the size of each element, but nothing about shape or strides. Those are entirely the layout’s job.
Indexing a tensor at coordinate \(c\) is then a two-step operation:
\[\text{tensor}[c] \;=\; \ast(\text{iterator} + \text{layout}(c))\]Compute the offset using the layout function, add it to the iterator, dereference. The layout decides where in memory; the iterator decides what’s there.
# Suppose `some_pointer` is a typed pointer to a buffer of FP32 values in some memory space.
tensor = cute.make_tensor(some_pointer, my_layout)
# Indexing the tensor at coord (2, 1):
val = tensor[(2, 1)] # equivalent to *(some_pointer + my_layout((2, 1)))
The iterator stays the same across operations; only the layout changes if you re-view the tensor differently. This is the key separation that makes CuTe’s machinery work: data is one thing, indexing is another.
Aliasing: two views into the same data
Because a tensor is just an iterator and a layout, a direct consequence is that two tensors can share an iterator but use different layouts. The two tensors point at the same physical bytes (registers, SMEM, GMEM, whichever) but interpret those bytes through different coordinate functions. No data is copied, no memory motion happens. You just have two different ways to walk the same storage.
This is called aliasing, and it’s used constantly in kernel code. The most important reason: hardware-imposed layouts and human-friendly layouts often disagree, and aliasing lets you write code under the human-friendly view while the hardware still sees the layout it expects.
A concrete example from FA3 (we’ll see this again in K4):
# `acc_pv` is a register tensor allocated by CuTe for the PV WGMMA accumulator.
# Its layout is whatever hierarchical structure the WGMMA atom expects, which is
# not a clean 2D rectangle from a programmer's point of view.
acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape)
# We want a friendlier 2D view of the same physical registers, where mode-0 is
# "M row in this thread's view of the tile" and mode-1 is "N column."
# Same iterator (same registers), different layout.
acc_pv_mn = cute.make_tensor(
acc_pv.iterator,
self.layout_acc_mn(pv_tiled_mma, acc_pv.layout),
)
Both acc_pv and acc_pv_mn refer to the exact same registers. Writing acc_pv_mn[i, j] = x is sugar over a register store, indexed in a way humans can reason about. The store lands in whichever physical register the alias’s layout maps (i, j) to, which is the same physical register the hardware-native acc_pv knows about.
Aliasing is a no-op at runtime. The compiler sees two layout objects over one iterator and generates indexing code for whichever view is being used at each point. No bytes move. No registers are reallocated. The two tensors are the same data, indexed two ways.
Identity tensors: coordinate-reporting views
CuTe also lets you build a tensor whose values are coordinates rather than data. Indexing it at coord \(c\) returns \(c\) itself, as a tuple. This is called an identity tensor, and it has no underlying storage. It exists only as a probe to ask the question: “for thread T’s slice of some partitioned tensor, what tile-coordinates do thread T’s elements correspond to?”
# An identity tensor over the shape (Br, d).
# cO[(m, n)] returns the tuple (m, n) itself, not a data value.
cO = cute.make_identity_tensor((self.Br, self.d))
Identity tensors come into their own when partitioned the same way a real tensor is partitioned. If you apply the same partitioning that distributes an output tile across threads to an identity tensor over that same tile, indexing the partitioned identity at a per-thread slot tells you which \((M, N)\) position in the original tile that slot corresponds to. We’ll see this concretely in the partitioning subsection and again when we write the output scatter in a kernel.
The takeaway here: identity tensors are coordinate probes. They piggyback on the same layout/partitioning machinery the real tensors use, so they always agree with the real tensors on which thread owns what. They’re a debugging and indexing tool, not a data structure.
Tensor data is runtime; layout is JIT-time
We touched on this in the JIT subsection but it’s worth restating here in the tensor context. When you write tensor[c], two things happen at different times:
- The offset computation \(\text{layout}(c)\) is purely arithmetic on JIT-time-known shape and stride values. The compiler can fold offsets into immediates, unroll inner loops, and emit specialized addressing.
- The dereference is a runtime memory operation. Whatever value is at that byte address right now is what comes back.
So the layout is part of the kernel’s compiled structure (baked in at JIT time), and the actual data flows through at runtime. This is what lets CuTe specialize aggressively: the indexing code is a sequence of cheap integer additions on compile-time constants, and the loads and stores hit the right registers or memory locations directly.
With layouts (coordinate functions) and tensors (layouts plus pointers) in place, the next question is: when a kernel runs with hundreds of threads, how does each thread know which subset of a tensor to operate on? That’s what partitioning answers.
Partitioning and the TiledMMA object
A kernel runs with many threads (one warpgroup = 128 threads on Hopper, often more), and they cooperate on a tile. The question this subsection answers: given a tile in SMEM, in registers, or in GMEM, which slice of it does each thread own, and how does CuTe express that?
The short answer: there’s an object called TiledMMA that knows everything about how a warpgroup-level Tensor-Core instruction distributes work across threads. It carries the partitioning recipes for the three operands \(A\), \(B\), \(C\), and we use it to produce per-thread tensors that line up with what the hardware expects. Most of the deep mechanics (the exact fragment layouts, which thread holds which value of a WGMMA output) live with K4. Here we cover the abstractions and enough of the API to read kernel code.
The TiledMMA object
A TiledMMA is a CuTe object that bundles three things:
- An MMA atom: the Tensor-Core instruction we want to issue. For our kernels this is WGMMA on Hopper, with a specific shape and dtype combination (e.g., a \(64 \times 64 \times 16\) BF16 atom).
- An atom layout: how many copies of the atom we tile across the CTA’s work region (e.g., two atoms along the M direction, one along N and K).
- The partitioning recipes that say how each atom’s \(A\), \(B\), and \(C\) tiles are distributed across the warpgroup’s 128 threads.
In CuTe Python we build it with a helper. From FA3:
qk_tiled_mma = sm90_utils.make_trivial_tiled_mma(
INPUT_DTYPE, INPUT_DTYPE, # dtypes of A and B
warpgroup.OperandMajorMode.K, # A is K-major
warpgroup.OperandMajorMode.K, # B is K-major
ACC_DTYPE, # accumulator dtype
self.atom_layout_mnk, # tiling across atoms
(self.Br, self.Bc), # the CTA's work tile shape (M, N)
warpgroup.OperandSource.SMEM, # where A comes from
)
This is one TiledMMA, configured for the QK matmul (\(S = Q K^\top\)). FA3’s PV matmul (\(O = P V\)) gets its own TiledMMA with different settings (e.g., \(A\) sourced from registers, since the freshly-computed \(P\) from the previous matmul lives in registers). The key point for now: a TiledMMA is a per-matmul object, configured once on the host (well, JIT-time), and reused for every iteration of the inner loop.
What about the operand-source choice (“\(A\) from SMEM”), the major modes, and the atom layout? These are decisions about how the matmul talks to the hardware: where the data lives, in what byte order, and how many parallel atom-issues we want. They genuinely matter for performance, but they’re decisions you make when designing the kernel (K3 introduces them with TMA, K4 with WGMMA). For now, treat the TiledMMA as an opaque object that knows everything about partitioning.
Getting a per-thread slice: get_slice
A TiledMMA describes the warpgroup-level partitioning. To get the slice for a specific thread, you ask the TiledMMA for it:
thr_mma = tiled_mma.get_slice(tidx)
thr_mma is a thread-local view of tiled_mma that knows which thread you are (tidx is the thread index inside the warpgroup, 0 to 127). Every subsequent partition_* and make_fragment_* call on thr_mma will produce a tensor that is this thread’s piece of the warpgroup-level tile.
Importantly, calling get_slice is cheap. The TiledMMA itself holds JIT-time partitioning metadata; get_slice just baking the thread index in.
partition_A, partition_B, partition_C
The three partitioning calls are:
tCsQ = thr_mma.partition_A(sQ_full) # this thread's slice of A, in SMEM
tCsK = thr_mma.partition_B(sK_full) # this thread's slice of B, in SMEM
tCgC = thr_mma.partition_C(gC_tile) # this thread's slice of C, in GMEM
Each call takes a source tile (anywhere in memory) and returns a tensor whose layout is this thread’s slice under the corresponding operand’s partitioning recipe. Crucially, the returned tensor still points at the original storage. Partitioning doesn’t copy any bytes; it builds a re-layout that walks only this thread’s portion of the tile.
So tCsQ has the same iterator as sQ_full (it points to the same SMEM location), but its layout is the per-thread \(A\) partitioning that the WGMMA atom expects, sliced down to thread tidx. The compiler knows exactly which bytes this thread will touch.
Two subtleties worth flagging up front:
- For WGMMA’s SS variant (where both \(A\) and \(B\) come from SMEM),
partition_Aandpartition_Bare slightly degenerate: every thread of the warpgroup gets the same layout view, because the Tensor-Core hardware reads SMEM directly via a descriptor, not via per-thread loads. We’ll see why in K4. The namepartition_Ais a holdover from the more general API where per-thread loads do exist; in SS mode it’s really a “reshape into the WGMMA-expected layout” rather than a true thread-by-thread slicing. For our purposes the API call is the same regardless. - For \(C\) (the accumulator), the partitioning is genuinely per-thread, because the accumulator lives in registers and each thread owns a specific subset of the output tile’s values.
partition_C(gC_tile)returns this thread’s slice of where its register-held values should land in GMEM.
make_fragment_A, make_fragment_B, make_fragment_C
A fragment is the WGMMA-encoded view of an operand, ready to feed to the Tensor Core. Two variants:
- For SMEM-sourced operands,
make_fragment_A(tCsA)wraps the partitioned SMEM tensor in the exact layout encoding WGMMA’s SMEM-descriptor instruction expects. Same iterator (still pointing into SMEM), new layout. No register allocation happens. - For register-sourced operands,
make_fragment_C(shape)actually allocates a per-thread register tensor of the right size to hold this thread’s portion of the accumulator. This is a real allocation; the tensor’s iterator points at freshly reserved registers.
# SMEM-sourced A (QK matmul): reinterpret tCsQ in the WGMMA-A layout.
tSrQ = qk_tiled_mma.make_fragment_A(tCsQ)
# Register-allocated C: this is where the warpgroup will accumulate.
qk_acc_shape = qk_thr_mma.partition_shape_C((self.Br, self.Bc))
acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape)
The first call doesn’t move data; the second one allocates registers. Both produce tensors that cute.gemm knows how to feed directly to a WGMMA instruction.
A note on what’s actually in registers per thread: a single \(64 \times 64\) WGMMA atom produces \(4096\) output values, computed cooperatively by all 128 threads of the warpgroup. So each thread holds \(4096 / 128 = 32\) of those values in its registers. These 32 values are not a contiguous sub-rectangle of the tile; they’re scattered across it in a pattern fixed by the WGMMA hardware. This scatter pattern is the heart of TV layouts, which we get to next. K4 covers the exact pattern.
A note on SMEM descriptors
We’ve said twice now that the Tensor Core “reads SMEM directly via a descriptor.” It’s worth to anchor what this means, since the term shows up in any WGMMA discussion you’ll encounter.
A SMEM descriptor is a 64-bit value the Tensor Core uses to find an operand tile in SMEM. It packs the tile’s base address, the strides needed to walk between sub-blocks of the tile, the swizzle pattern applied to the SMEM layout, and an alignment offset, everything the hardware needs to walk the entire operand tile on its own. The crucial property: with the descriptor in hand, the Tensor Core does not need per-thread loads to assemble the operand. It reads SMEM directly through hardware-managed addressing, using the descriptor as the address recipe.
This is why make_fragment_A and make_fragment_B perform no data motion when the operand is sourced from SMEM: there’s nothing to move. The “fragment” is a layout reinterpretation that lets CuTe construct the right descriptor when cute.gemm issues the WGMMA instruction. Every thread of the warpgroup ends up holding the same 64-bit descriptor value in a register, because the descriptor describes warpgroup-level data and all 128 threads consume it identically. By contrast, register-sourced operands (used in FA3’s PV matmul, where the freshly-computed P from QK’s softmax lives in registers) need to actually have their data in the right physical registers per thread, because there’s no descriptor mechanism for register operands. K4 covers the descriptor’s internal fields and how the descriptor mechanism enables WGMMA’s asynchrony.
TV layouts: thread-and-value partitioning
The partitioning recipes inside a TiledMMA are themselves CuTe layouts, with a specific structure. Two top-level modes:
\[(T, V) \;\longrightarrow\; \text{position in the tile}\]The first mode is the thread index \(T\) (which thread in the warpgroup, \(T \in [0, 128)\)). The second mode is the value index \(V\) (which one of this thread’s owned values, \(V \in [0, 32)\) for the \(64 \times 64\) case). The layout function tells you, for any \((T, V)\) pair, which position in the underlying tile that value corresponds to.
Every TiledMMA exposes three TV layouts, one per operand:
tiled_mma.tv_layout_A
tiled_mma.tv_layout_B
tiled_mma.tv_layout_C
The crucial property: both the \(T\) and \(V\) modes can be hierarchical, and both contribute to both axes of the tile. Sub-bits of the thread index walk along \(M\); other sub-bits walk along \(N\); same for the value index. So a single thread’s 32 values are scattered across the \(64 \times 64\) tile in a hardware-defined pattern, not arranged as a clean sub-rectangle. This is the structural reason for everything that’s awkward about per-thread reasoning in WGMMA code, and it’s what makes hierarchical layouts essential rather than decorative.
The reason this matters for kernel code: when softmax wants rowmax and rowsum, it’s reducing over a row of the score tile. A given row’s values are split across multiple lanes’ registers, so the reduction needs warp-level coordination (warp shuffles) to combine the per-lane contributions. The pattern of which lanes share a row is encoded in tv_layout_C’s strides. K4 unpacks this carefully; here we just establish that TV layouts are the object that holds it all.
The tXrY / tXsY naming convention
Reading CuTe kernel code requires understanding a naming convention that’s used consistently throughout. Every partitioned tensor has a four-character name of the form:
\[t\langle X \rangle\langle \text{mem} \rangle\langle \text{Operand} \rangle\]-
t(“thread-partitioned”) is always there. It tells you this tensor came from a partitioning call. -
Xis which matmul this tensor is for. In FA3,Smeans the QK matmul (because the output is \(S\), the score matrix), andOmeans the PV matmul (because the output is \(O\)). - mem is where the data lives:
sfor SMEM,rfor registers,gfor GMEM,cfor coordinates (the identity-tensor case from the previous subsection). - Operand is which operand role:
Q,K,V,P,S,O.
A few concrete decodings from FA3 kernel code:
-
tSsQ= QK matmul, thread-partitioned, SMEM-resident, Q-operand. This is the per-thread slice of \(Q\) in SMEM, the input to the QK matmul. -
tSrK= QK matmul, thread-partitioned, register-fragment view, K-operand. This is the WGMMA-A-encoded view of \(K\) that gets fed tocute.gemm. For SMEM-sourced K, this is a layout reinterpretation, not a register allocation. -
tOrP= PV matmul, thread-partitioned, register-resident, P-operand. The output of the QK softmax, reinterpreted as the \(A\) operand of PV. -
tOcO= PV matmul, thread-partitioned, coordinate tensor, O-destination. Used as a coordinate probe for the output scatter (the identity-tensor pattern from the previous subsection).
The convention is just a convention, not enforced by CuTe. But every kernel in this blog follows it strictly, and most kernels in CUTLASS-land do too. Once you internalize it, you can decode unfamiliar kernel code at a glance: see tOsV, know it’s “PV matmul’s \(V\) tile in SMEM, partitioned per thread,” and move on.
Putting it all together. To make this section concrete, here’s the full chain of objects you build up for one matmul in a CuTe kernel, end to end. You start on the host side (at JIT time) by constructing a TiledMMA for that matmul, which bundles three things: the MMA atom (the specific WGMMA instruction, with its fixed dtype and shape, e.g. \(64 \times 64 \times 16\) BF16), the atom layout (how many atoms tile across the CTA’s work region, e.g. two atoms along M), and the TV layouts (tv_layout_A, tv_layout_B, tv_layout_C) that encode how the warpgroup’s 128 threads split each operand’s tile. The TiledMMA itself doesn’t touch any data; it’s a pure JIT-time description of how the matmul will be carried out by the hardware.
Once you’re inside the kernel (on the device, at runtime, though all the layout machinery is still JIT-time), each thread asks the TiledMMA for its own slice: thr_mma = tiled_mma.get_slice(tidx). This thread-local view of the TiledMMA is what every subsequent partition_* and make_fragment_* call goes through. It is cheap, it just bakes the thread index into the TiledMMA’s metadata, no data motion.
Then you partition the operand tiles. The full operand tiles live somewhere (in SMEM, after a TMA load, or in registers, after a previous matmul). You hand each one to its matching partition call: tCsA = thr_mma.partition_A(sA_full) for the A operand, tCsB = thr_mma.partition_B(sB_full) for B, tCgC = thr_mma.partition_C(gC_tile) for C. Each call returns a partitioned tensor: same iterator as the source (no bytes moved), but a new layout that is this thread’s slice under the corresponding TV layout. The compiler now knows exactly which bytes this thread will touch.
The partitioned tensors are not yet in the form WGMMA wants. The Tensor Core expects operand tiles wrapped in a specific encoded layout, and the accumulator wrapped in a register-allocated fragment. So you wrap them with make_fragment_*: tCrA = tiled_mma.make_fragment_A(tCsA) reinterprets the SMEM-resident A in the WGMMA-A encoding (no data motion, since the Tensor Core will read SMEM directly via a descriptor; the “fragment” here is just a layout-level reshape). tCrB = tiled_mma.make_fragment_B(tCsB) does the same for B. acc = thr_mma.make_fragment_C(shape) is different: it genuinely allocates per-thread registers to hold this thread’s share of the accumulator (32 values per thread for a \(64 \times 64\) atom, scattered across the tile in the hardware-defined Z-pattern). This is the only call in the chain that performs a real allocation.
At this point you have three per-thread tensors (tCrA, tCrB, acc) ready to feed to cute.gemm. Every byte these tensors point at is in the right place and the right layout for WGMMA. The matmul itself, the actual cute.gemm dispatch that issues WGMMA instructions, is what the next subsection covers.
A useful summary, then, of what each step does and where the data lives:
- TiledMMA — JIT-time, no data. The complete description of how the matmul maps to hardware.
-
get_slice(tidx)— JIT-time, no data. A per-thread view of the TiledMMA. -
partition_*— runtime tensors, but no data motion. Re-layouts of the source tiles, one per operand, sliced to this thread. -
make_fragment_A/make_fragment_B— runtime tensors, but no data motion. Layout reinterpretations that match the WGMMA encoding. For SMEM-sourced operands these are pure layout reshapes; the Tensor Core will read SMEM directly via a descriptor (more on this in K4). -
make_fragment_C— runtime tensor, real allocation. Per-thread registers reserved to hold the accumulator.
Everything before cute.gemm is preparation: arranging layouts and reserving registers. The actual matmul, and the rest of the kernel’s data motion (TMA loads, SMEM-to-register movement), is what consumes the tensors we’ve built. We turn to cute.gemm next.
cute.gemm: dispatching the matmul
We now have everything cute.gemm consumes: a TiledMMA that knows how the matmul maps to hardware, and three per-thread tensors (tCrA, tCrB, acc) holding the partitioned operands in the right encoding. The call is short:
cute.gemm(tiled_mma, acc, tCrA, tCrB, acc)
That’s the entire dispatch. Behind it, CuTe figures out which WGMMA instructions to issue, in what order, with what accumulator state. This subsection unpacks what the call actually does, at the abstraction level. The hardware-side details (the inline PTX, the instruction pipeline, the synchronization) live with K4.
What cute.gemm consumes
The arguments line up directly with the matmul we want:
-
tiled_mma— the TiledMMA built earlier. It carries the MMA atom, the atom layout, and the TV layouts. Everything CuTe needs to emit the right instruction is in here. -
acc— the destination accumulator (this thread’s register fragment), passed both as the write destination and the read source for the accumulated value. Same tensor in both positions because WGMMA computesC = A @ B + C(orC = A @ Bwhen accumulation is off, controlled separately). -
tCrA,tCrB— the per-thread A and B fragments, as produced bymake_fragment_*. For SMEM-sourced operands these are descriptor-backed layout reinterpretations; for register-sourced operands they hold the actual data.
The dispatch is per-thread, but the effect is warpgroup-level: when all 128 threads call cute.gemm together, the Tensor Core sees a coordinated set of operand descriptors and accumulator registers, and it issues the actual WGMMA instructions for the warpgroup.
The two-level dispatch
Internally, cute.gemm has a two-level structure that’s worth understanding because it makes the shapes of tCrA, tCrB, and acc make sense.
The partitioned tensors are not single-atom-sized. The TiledMMA’s MMA atom is a specific shape (e.g., \(64 \times 64 \times 16\)), but the CTA’s work tile usually requires multiple atom-issues to cover. So each partitioned tensor has a layout shape with one mode for one atom’s worth of data (called the V mode, for “value”) and additional modes counting how many atoms fit along each axis (MMA_M, MMA_N, MMA_K).
The conventional shape for the A fragment is:
\[\text{tCrA shape} = (V,\; \text{MMA\_M},\; \text{MMA\_K})\]For B:
\[\text{tCrB shape} = (V,\; \text{MMA\_N},\; \text{MMA\_K})\]For C (the accumulator):
\[\text{acc shape} = (V,\; \text{MMA\_M},\; \text{MMA\_N})\]For a concrete example: with a \(64 \times 64 \times 16\) MMA atom and a CTA work tile of \(B_r \times d = 64 \times 64\) for the QK matmul, MMA_M = 1 (one atom along M), MMA_N = 1 (one along N), MMA_K = 4 (four atoms along K, since the K dimension is \(d = 64\) and each atom covers \(K = 16\)). The V mode is one atom’s per-thread data, encoded as whatever layout the WGMMA atom expects (a single descriptor for A/B; 32 register slots for C).
Given that structure, cute.gemm works in two levels:
- The outer level walks the atoms. It loops over MMA_M, MMA_N, and MMA_K, visiting each \((m, n, k)\) position in turn. The dimensions are JIT-time constants, so this loop is fully unrolled at compile time, the SASS that runs on the GPU is a flat sequence of WGMMA issues, not an actual loop.
- The inner level issues one WGMMA per atom. At each \((m, n, k)\) position,
cute.gemmslices out the V-mode of each input tensor (one atom’s worth of data, for A this is a single descriptor; for the accumulator this is the 32 register slots this thread owns for the atom). With one atom’s data isolated, it calls into the MMA atom, which issues exactly one WGMMA instruction. The PTX side of this issue, the inline assembly, the operand registers, the immediate flags, is what K4 covers.
So one cute.gemm call expands into MMA_M × MMA_N × MMA_K WGMMA issues. In our QK example that’s \(1 \times 1 \times 4 = 4\) WGMMA instructions per call. The K-reduction loop (4 atoms along K) accumulates into the same acc slot, summing partial products as K is swept. The M and N loops produce different output regions; for QK with MMA_M = MMA_N = 1, there’s only one output region per call.
This is also why the V mode exists as a distinct first mode: it’s the unit cute.gemm’s inner dispatch operates on. After stripping the outer (m, n, k) indices, V is one atom’s worth of data, which is exactly what mma_unpack and the MMA atom’s fma method consume.
The ACCUMULATE flag
WGMMA can either accumulate into the existing C state (\(C = A B + C\)) or write fresh (\(C = A B\)). This is controlled by an ACCUMULATE field on the TiledMMA, set per-issue:
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, accumulate_flag)
cute.gemm(tiled_mma, acc, tCrA_k, tCrB_k, acc)
For the K-reduction inside one matmul, you want the first K-step to write (so accumulate_flag = False, since acc may contain garbage at the start), and every subsequent K-step to add (accumulate_flag = True). FA3’s kernel code expresses this with ACCUMULATE = (k_block_idx != 0).
For matmuls whose accumulator is a running quantity across multiple kernel iterations (like PV in FlashAttention, where the output is accumulated across all KV chunks), ACCUMULATE is always True, and the accumulator is zero-initialized once at the start.
What cute.gemm does not do
A few things are worth flagging because they’re easy to assume cute.gemm handles automatically:
- It does not synchronize. WGMMA is asynchronous:
cute.gemmissues the instructions and returns. The actual computation continues on the Tensor Core in parallel with the issuing warpgroup. You need explicit synchronization (afencebefore the issue, acommit_groupandwait_group(N)to wait for results) bracketing the call. K4 covers this in detail; for now, register that the synchronization is the caller’s responsibility. - It does not load data into SMEM or registers. The partitioned tensors must already point at the right data when
cute.gemmis called. The data motion (TMA loads for SMEM, the QK→PV register handoff for the register-sourced case) is separate code that runs before the dispatch. - It does not allocate the accumulator. That happened back in
make_fragment_C.cute.gemmreads and writes the existing registers. - It does not unroll itself at runtime. The outer (MMA_M, MMA_N, MMA_K) loop is fully unrolled at JIT time, because the dimensions are JIT-time constants. The kernel that hits the SASS is a flat sequence of WGMMA instructions, not a loop.
The full structure around a cute.gemm call
Putting this together with the synchronization primitives (covered in detail in K4), a complete WGMMA dispatch looks roughly like:
cute.nvgpu.warpgroup.fence() # register operands ready
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0)
cute.gemm(
tiled_mma,
acc,
tCrA[(None, None, k_block_idx)],
tCrB[(None, None, k_block_idx)],
acc,
)
cute.nvgpu.warpgroup.commit_group() # close the batch
cute.nvgpu.warpgroup.wait_group(0) # wait for it
The range_constexpr loop and the per-iteration ACCUMULATE toggle handle the K-reduction. The slicing tCrA[(None, None, k_block_idx)] selects the k-th k_block from tCrA’s shape (V, MMA_M, MMA_K), picking out one atom along K. The fence-commit-wait triplet brackets the asynchronous WGMMA pipeline.
We’re handwaving over the synchronization here. The detailed semantics, why the fence is needed, what commit_group and wait_group actually do, why Hopper’s same-accumulator-shape exception lets multiple WGMMAs pipeline without internal fences, all live with K4. Here the point is just that cute.gemm is the dispatch primitive that produces the actual WGMMA instructions, and that synchronization is a separate concern wrapping it.
With cute.gemm covered, the only major piece of CuTe machinery we haven’t introduced is the data-motion side: how data gets into SMEM from GMEM in the first place. That’s what cute.copy and TMA atoms handle, and they’re the subject of the next subsection.
cute.copy and TMA atoms: dispatching data motion
We’ve covered how a kernel does its math (cute.gemm). Now we need the other side: how the data actually gets to where the math needs it. SMEM doesn’t fill itself; the kernel has to issue loads from GMEM. On Hopper, the right way to do this is TMA (Tensor Memory Accelerator), the dedicated hardware unit for bulk asynchronous tile copies. CuTe wraps TMA with the same atom-and-dispatch pattern we just saw for WGMMA. This subsection covers that pattern at the abstraction level, the deep TMA mechanics (descriptors, mbarriers, pipelines) live with K3.
What TMA is, briefly
TMA is the SM90 hardware unit that performs bulk asynchronous copies of multidimensional tiles between GMEM and SMEM. Three properties make it different from the pre-Hopper cp.async model that you may have seen in earlier-generation kernels:
- One thread issues the whole copy. A single thread executes the TMA instruction; the hardware does the rest.
- The tile shape is multidimensional and known to hardware. Before the kernel runs, you build a tensor descriptor that captures the source tensor’s shape, strides, and dtype. The hardware uses this descriptor to walk the source tile correctly without per-thread address computation.
- The copy is asynchronous. TMA issues return immediately; the thread that issued continues executing while the copy proceeds in the background. Completion is tracked by a barrier.
For our purposes here, we just need to know that TMA exists, that we use it to load Q, K, and V tiles from GMEM into SMEM, and that CuTe gives us a primitive to issue these copies. K3 covers TMA’s mechanics (descriptors, mbarriers, pipelines) in depth.
TMA atoms
Just as a WGMMA matmul is wrapped by an MMA atom, a TMA copy is wrapped by a TMA atom. The TMA atom bundles three things:
- The TMA operation type (e.g., a single-CTA GMEM-to-SMEM bulk tile copy).
- A tensor descriptor for the GMEM source (its shape, strides, dtype), built once at JIT time.
- The SMEM destination layout (which tile structure in SMEM the copy lands in, with its swizzle pattern).
In CuTe Python you build it with a helper:
tma_atom_q, tma_tensor_q = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(), # the op: G2S (GMEM→SMEM) bulk tile
mQ, # GMEM source tensor (any layout)
sQ_layout, # SMEM destination layout
(self.Br, self.d), # tile shape to copy per issue
num_multicast=1, # no broadcast across CTAs
)
This returns two things: the TMA atom itself (tma_atom_q, the handle holding the descriptor and metadata, used in subsequent dispatch calls) and a TMA-compatible re-wrap of the source tensor (tma_tensor_q). From here on, use tma_tensor_q, not mQ, when partitioning for TMA. cute.copy from raw mQ will not work.
Building a TMA atom is a JIT-time operation. The descriptor’s shape, strides, and dtype are baked in. At runtime, each TMA issue is just “load the tile at this coordinate from GMEM into this SMEM slot, using the metadata in the atom.”
You also typically prefetch the descriptor into the SM’s caches before the first issue, to avoid descriptor-fetch latency on the first load:
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q)
FA3 does this for all three operands (Q, K, V) on warp 0 before the main loop starts.
Partitioning for TMA: tma_partition
Once the TMA atom is built, partition the source and destination tensors for the issue. The function is analogous to the MMA partition_* calls, but specific to TMA:
tQsQ, tQgQ = cute.nvgpu.cpasync.tma_partition(
tma_atom_q,
0, # multicast slot index (0 = no multicast)
cute.make_layout(1), # CTA layout (trivial 1-element = single CTA)
cute.group_modes(sQ_full, 0, 2), # SMEM dest, tile interior grouped
cute.group_modes(gQ_tiles, 0, 2), # GMEM source, tile interior grouped
)
A few moving pieces here:
- The
0and trivial 1-element layout say “no multicast, this single CTA only.” Multicast lets one TMA broadcast the same tile to multiple CTAs, which is useful for shared inputs but not what we use in our basic FA3 kernel. -
cute.group_modes(tensor, 0, 2)collapses the first two modes of a tensor into one. TMA treats “the tile” as a single opaque blob (the hardware knows how to walk it via the descriptor), so we collapse the tile-interior modes into one mode. The outer modes remain to identify which tile in the larger source.
A quick aside on pipelining, because the destination shape mentions stages. Loading a tile from GMEM is slow relative to consuming it with WGMMA, so a serial “load one tile, compute on it, load the next” pattern leaves the Tensor Core idle most of the time. The standard fix is to keep multiple loads in flight at once: allocate several SMEM slots, called stages, and arrange them as a circular buffer. While the consumer (WGMMA) is computing on stage \(i\), the producer (TMA) is already filling stages \(i+1\), \(i+2\), and so on. The SMEM destination layout therefore has an extra outermost mode counting the stages, and a tile loaded at iteration \(j\) lands in slot \(j \bmod \text{n\_stages}\). K3 covers the producer-consumer coordination and the mbarrier mechanics that make this work; here we just need to know that the stage dimension exists in the SMEM destination layout.
The return values follow the same naming convention we saw with MMA partitioning:
-
tQsQ— TMA-partitioned SMEM destination. Shape(tma_unit, n_stages)wheren_stagesis the pipeline-stage count if the SMEM has a circular buffer. -
tQgQ— TMA-partitioned GMEM source. Shape(tma_unit, tile_coord)where the outer mode identifies which tile to load.
The naming convention extends from MMA partitioning: the first letter t for thread-partitioned, the second identifies which operation’s partitioner produced it (here Q because we’re partitioning for the Q tensor’s TMA copy), and the third indicates memory space (s for SMEM, g for GMEM). So tQsQ reads as “TMA-partition for Q, in SMEM” and tQgQ as “TMA-partition for Q, in GMEM.”
Issuing a copy: cute.copy
With the atom built and the source and destination partitioned, you issue the copy with cute.copy:
cute.copy(
tma_atom_q, # the atom
tQgQ[(None, bidx_m)], # source: tile at GMEM coord bidx_m
tQsQ[(None, q_producer_state.index)], # dest: SMEM at stage q_producer_state.index
tma_bar_ptr=q_pipeline.producer_get_barrier(q_producer_state), # mbarrier for completion
)
This is analogous in spirit to cute.gemm. The atom carries the operation metadata; the partitioned tensors say which tile and which destination; cute.copy produces the actual TMA instruction.
A few things to register:
- Source indexing.
tQgQ[(None, bidx_m)]slices the GMEM-source partition at the tile coordinate.bidx_mis the M-tile index for this CTA. TheNoneleaves the TMA-unit mode unsliced (TMA always processes the whole tile-unit). - Destination indexing.
tQsQ[(None, q_producer_state.index)]slices the SMEM destination at a specific pipeline stage. We’re writing into stagestate.indexof the SMEM circular buffer. - Completion barrier.
tma_bar_ptris an mbarrier (memory barrier) that the TMA hardware will arrive on when the copy completes. The consumer (the WGMMA that wants this data) waits on this barrier before reading SMEM. The whole pipeline machinery, the producer-consumer dance, the circular buffer, the barrier arrivals, lives with K3. - Asynchrony.
cute.copyis asynchronous, likecute.gemm. The instruction issues, the function returns, and the actual copy proceeds on the TMA hardware in the background.
You typically issue copies on a single “producer” warp (e.g., warp 0) while the rest of the warpgroup acts as consumers. The producer issues the next batch of TMAs; the consumers wait for prior batches to land and then run WGMMA on them. This producer-consumer pattern, the pipeline-stage management, and the mbarrier mechanics are all covered in K3.
What cute.copy and TMA atoms do not do here
Mirroring the cute.gemm story, a few things are easy to assume but are not the dispatch primitive’s job:
- They do not synchronize with consumers. The mbarrier passed in tracks completion, but the consumer’s
waiton that barrier is separate code. - They do not allocate SMEM. That’s done once when the kernel allocates its
SharedStorage.cute.copyreads/writes the existing SMEM. - They do not manage pipeline state. The producer/consumer state advancement, the producer_acquire / consumer_release dance, is wrapped by a pipeline object (covered in K3).
- They do not handle multi-tile sweeps. One
cute.copyissues one TMA: one tile, one destination. The kernel loops over tiles explicitly, issuing one TMA per loop iteration.
A note on swizzle and SMEM layout
We’ve mentioned that the SMEM destination layout has a “swizzle pattern.” Briefly: SMEM has 32 banks, and naive layouts cause bank conflicts when many threads access SMEM simultaneously. Swizzling permutes addresses (via XOR-based bit manipulation) so accesses scatter across banks. The TMA descriptor and the SMEM layout must agree on the swizzle pattern: TMA writes bytes in the swizzled order, and downstream reads (by WGMMA) expect them in that order.
CuTe provides pre-built SMEM layouts that satisfy the swizzle requirements automatically. In FA3 these come from sm90_utils.make_smem_layout_b(...) and similar helpers. The pattern matters when the SMEM is consumed by WGMMA, because WGMMA’s expected layout has a specific swizzle. This is K3 (TMA side) and K4 (WGMMA side) territory; here we just register that swizzle exists and that the helpers produce the right layouts automatically.
The full CuTe vocabulary, recap
We’ve now built up everything CuTe needs to express a kernel: layouts, tensors, partitioning, and the two dispatch primitives cute.gemm and cute.copy. To close this section, here’s the full vocabulary in one place, with what each piece does:
- Layouts are coordinate-to-offset functions, JIT-time constants, the planning step.
- Tensors are
(iterator, layout)pairs that attach a pointer to a layout; indexing is where computation meets memory. - TiledMMA describes how a matmul maps to the hardware (MMA atom + atom layout + TV layouts).
-
partition_A/B/Cproduce per-thread re-layouts of operand tiles, no data motion. -
make_fragment_A/Bwrap partitioned tensors in the WGMMA-encoded form (no allocation for SMEM sources; the Tensor Core reads SMEM via a descriptor). -
make_fragment_Cactually allocates per-thread registers for the accumulator. -
cute.gemm(tiled_mma, acc, A, B, acc)dispatches the matmul, expanding into one WGMMA per atom. - TMA atom wraps a single TMA copy operation with its descriptor and metadata.
-
tma_partitionprepares source and destination tensors for TMA dispatch. -
cute.copyissues the actual TMA copy (asynchronous, completion signaled on an mbarrier). - Identity tensors are coordinate-reporting views, used as probes to figure out where partitioned values live in the original tile.
-
cute.flat_dividetiles a tensor into chunks for per-block work. - TV layouts (
tv_layout_A/B/C) encode how threads and per-thread values map to positions in the tile, scattered in hardware-defined patterns.
That’s the full vocabulary. Every kernel from K1 through K8 uses this same set of primitives. What changes from kernel to kernel is which operands come from SMEM versus registers, how aggressively the loads are pipelined, whether the kernel splits warps into specialized roles, and what precisions it uses, but the building blocks are the same. We now have everything we need to start writing kernels.