K1 in CuTe [FA3]

This is the in-depth companion to the K1 section of the FA3 worklog. Refer back there for the algorithm, motivation, and results. This post covers the CuTe mechanics: how the three kernels are written in CuTe Python DSL, what primitives they use, and how the harness compiles and dispatches them.

The kernel file is kernels/k1.py in the project repo. Three CuTe classes (K1Score, K1Softmax, K1Output) plus a compile_kernel / run_kernel dispatcher pair.

The inner kernel: K1Score

The other two kernels follow the same shape, so we’ll read K1Score carefully and trust the structure for the others.

@cute.kernel
def kernel(self, mQ, mK, mS, scale: cutlass.Float32):
    bx, by, bz = cute.arch.block_idx()
    tx, ty, _ = cute.arch.thread_idx()

    j = bx * SCORE_BLOCK_X + tx
    i = by * SCORE_BLOCK_Y + ty
    b = bz

    if i < self.N and j < self.N:
        # Causal mask: kv position > query position => -inf.
        if cutlass.const_expr(self.causal) and j > i:
            mS[i, j, b] = -cutlass.Float32.inf
        else:
            gQ = mQ[(None, None, b)]
            gK = mK[(None, None, b)]
            dot = cutlass.Float32(0.0)
            for k in cutlass.range_constexpr(self.d):
                dot = dot + gQ[i, k].to(cutlass.Float32) * gK[j, k].to(cutlass.Float32)
            mS[i, j, b] = dot * scale

Walking through the CuTe-specific bits:

cute.arch.block_idx() and cute.arch.thread_idx() are the CuTe analogues of CUDA’s blockIdx and threadIdx. They return tuples; we destructure them into (bx, by, bz) and (tx, ty, _). Per-element work is then standard CUDA: each thread computes its global \((i, j, b)\) coordinate from block and thread indices.

mQ[(None, None, b)] is CuTe’s slicing pattern: None keeps the mode free, a concrete value fixes the coordinate. So mQ[(None, None, b)] fixes the batch coord at b and leaves the other two free, producing a 2D view gQ of shape (N, d) that points into the same GMEM storage as mQ. No data motion, just a layout reinterpretation. The full mechanics are in the CuTe DSL fundamentals notes.

cutlass.const_expr(self.causal) is the JIT-time-constant marker. The condition is resolved at compile time, not at runtime: the compiler emits a specialized kernel for either causal or non-causal, with no runtime branch. This is the analogue of C++ if constexpr. The same pattern lets the body of the if and the else branches expand into different code paths at JIT time.

cutlass.range_constexpr(self.d) is a fully-unrolled-at-JIT-time loop. Since d is a JIT-time constant (64 or 128 in our sweeps), the \(k\)-axis dot product is emitted as 64 or 128 straight-line FMA instructions, no loop overhead, no runtime branch. The same primitive shows up throughout the worklog whenever a loop’s trip count is known at JIT time.

The dot product itself is plain CUDA semantics: per-thread FP32 accumulator, BF16 GMEM reads cast up with .to(cutlass.Float32), scalar FMAs. No SMEM, no Tensor Cores. This is the absolute floor of GPU compute on H100.

The other two kernels (K1Softmax, K1Output) use the same primitives. K1Softmax reads its row in three passes (max, exp+sum, normalize) using a plain Python for k in range(self.N) with runtime trip count (since N is not JIT-constant per kernel; it varies per sweep config). K1Output mirrors K1Score’s structure with P and V instead of Q and K.

The dispatcher: compile_kernel and run_kernel

The harness calls two functions per kernel file. compile_kernel(B, N, d, causal, tensors) does the one-time JIT work and returns an opaque handle; run_kernel(handle, tensors) launches the kernel using that handle on the current stream. K1’s implementation:

def compile_kernel(B, N, d, causal, tensors):
    score = K1Score(B, N, d, causal)
    softmax = K1Softmax(B, N)
    output = K1Output(B, N, d)
    scale = cutlass.Float32(1.0 / math.sqrt(d))
    stream = get_cuda_stream()

    s_storage, s_cute = _make_S_tensor(B, N)

    score_compiled = cute.compile(
        score, tensors["q_cute"], tensors["k_cute"], s_cute, scale, stream,
    )
    softmax_compiled = cute.compile(softmax, s_cute, stream)
    output_compiled = cute.compile(
        output, s_cute, tensors["v_cute"], tensors["o_cute"], stream,
    )

    return (score_compiled, softmax_compiled, output_compiled,
            scale, s_storage, s_cute)

Two things to call out:

cute.compile(kernel_obj, *args) is where the JIT happens. The Python process introspects the @cute.jit callable, builds an IR, lowers it to PTX/SASS, and returns a compiled object that can be invoked like a function. The argument shapes and dtypes at compile time fix the specialization: a kernel compiled for \(N = 512, d = 64\) is a different binary from one compiled for \(N = 1024, d = 128\). The harness handles this by calling compile_kernel once per sweep config.

The \(S\) buffer is allocated inside compile_kernel, not per launch. This matters: the buffer is \(B \times N \times N \times 4\) bytes, which is up to several GB at large \(N\). If we allocated it on every benchmark iteration, the allocation cost would dominate timing and we’d risk OOM thrashing. Allocating once and reusing it across the 30 timed iterations of the sweep keeps the measurement focused on actual kernel work.

run_kernel is then trivial:

def run_kernel(compiled_handle, tensors):
    (score_compiled, softmax_compiled, output_compiled,
     scale, s_storage, s_cute) = compiled_handle
    stream = get_cuda_stream()
    score_compiled(tensors["q_cute"], tensors["k_cute"], s_cute, scale, stream)
    softmax_compiled(s_cute, stream)
    output_compiled(s_cute, tensors["v_cute"], tensors["o_cute"], stream)

Three sequential launches on the current CUDA stream. No host-device synchronization between them; the stream serializes the launches automatically.

Why K1 is a good starting point for CuTe

A few observations from getting K1 to compile and run, useful as priors for the harder kernels:

JIT specialization is real. self.causal, self.d, and the block shapes are all JIT-time constants. The compiled kernel has no branches for the causal flag, no loop counter for the \(k\) dot product, and constant-folded block dimensions. This is what makes CuTe Python deliver C++-like performance without C++.

Tensors are just (iterator, layout) pairs. Everything we did to slice and index Q, K, V, S was layout manipulation: no data motion, no allocation. The only actual storage operations are the __init__ calls in make_qkvo that allocate the GMEM buffers, and the one _make_S_tensor for the score matrix.

The kernel body looks like CUDA C++. block_idx, thread_idx, per-element work, runtime conditionals. The differences from a CUDA C++ kernel are minor: const_expr and range_constexpr for JIT-time decisions, dtype casts via .to(...), and cutlass.Float32(...) constructors for typed scalars. If you’ve written CUDA C++ before, the K1 body should read clean. The complexity comes later (K2 onward) when CuTe-specific abstractions like TV layouts and tiled MMAs enter the picture.

For the deeper CuTe machinery (layouts, tensors, partitioning, MMA atoms), refer to the CuTe DSL fundamentals notes. For the next kernel in the worklog, see K2: Tiled online softmax.