Transformer Block FLOPs & Parameter Calculations
Notation
- \(B\): Batch size
- \(S\): Sequence length
- \(D\): Hidden dimension / Model dimension
- \(N\): Number of attention heads
- \(H\): Head dimension \((H = D/N)\)
- \(F\): Feed-forward hidden dimension (typically \(4D\))
- \(S_1\), \(S_2\): Sequence lengths in the attention computation (\(S_1\) for queries, \(S_2\) for keys; both equal \(S\) in self-attention)
- \(\text{SiLU}\): Sigmoid Linear Unit activation (also known as Swish)
- \(\text{RoPE}\): Rotary Position Embeddings
- \([X \times Y]\): Matrix/tensor dimensions
- \(@\): Matrix multiplication operator
- \([X]\cdot[Y]\): Elementwise multiplication/dot-product
- \(L\): Number of layers
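To make the notation concrete, here is a small sketch with hypothetical sizes (the specific numbers are assumptions for illustration, roughly 7B-model scale; they are not taken from the text above):

```python
# Hypothetical example sizes (assumed for illustration only).
B = 8            # batch size
S = 2048         # sequence length
D = 4096         # hidden / model dimension
N = 32           # number of attention heads
H = D // N       # head dimension = 128
F = 4 * D        # feed-forward hidden dimension = 16384
L = 32           # number of layers
```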
Transformer Block Operations
Input
\[X = [B \times S \times D] \quad \text{(Input tensor)}\]
RMS Normalization 1
\[\text{RMSNorm}(x) = \frac{x}{\sqrt{\operatorname{mean}(x^2) + \epsilon}} \odot \text{gains}, \quad [B \times S \times D] \cdot [D] \rightarrow [B \times S \times D]\]
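A minimal PyTorch sketch of this step; the function name and the \(\epsilon\) value are assumptions for illustration:

```python
import torch

def rms_norm(x: torch.Tensor, gains: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last (D) axis,
    # then scale element-wise by the learned [D] gain vector.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)  # [B, S, 1]
    return x * rms * gains                                        # [B, S, D] * [D] -> [B, S, D]

x = torch.randn(2, 16, 64)      # [B, S, D]
gains = torch.ones(64)          # [D]
assert rms_norm(x, gains).shape == x.shape
```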
Multi-Head Attention (MHA)
Attention is applied to the output of RMS Normalization 1 (a code sketch follows the list of steps below):
- (a) \(Q = W_q @ x\), \(\quad [D \times D] @ [B \times S \times D] \rightarrow [B \times S \times D]\)
- (b) \(K = W_k @ x\), \(\quad [D \times D] @ [B \times S \times D] \rightarrow [B \times S \times D]\)
- (c) \(V = W_v @ x\), \(\quad [D \times D] @ [B \times S \times D] \rightarrow [B \times S \times D]\)
- (d) Rearrange for multi-head attention (\(Q, K, V\)): \([B \times S \times D] \rightarrow [B \times S \times (N \cdot H)] \rightarrow [B \times N \times S \times H]\)
- (e) Apply RoPE to \(Q\) and \(K\) (element-wise) \(\rightarrow [B \times N \times S \times H]\)
- (f) Scaled Dot-Product Attention with Causal Mask: \(\begin{aligned} QK^T &= [B \times N \times S_1 \times H] @ [B \times N \times S_2 \times H]^T \rightarrow [B \times N \times S_1 \times S_2] \\ \text{attn-weights} &= \operatorname{softmax}\!\left(\tfrac{QK^T}{\sqrt{H}} + \text{causal mask}\right) \rightarrow [B \times N \times S_1 \times S_2] \\ \text{output} &= \text{attn-weights} @ V = [B \times N \times S_1 \times S_2] @ [B \times N \times S_2 \times H] \rightarrow [B \times N \times S_1 \times H] \end{aligned}\)
- (g) Rearrange back: \([B \times N \times S \times H] \rightarrow [B \times S \times (N \cdot H)] \rightarrow [B \times S \times D]\)
- (h) Output Projection: \(X_{\text{out}} = W_o @ X_{\text{attn}}, \quad [D \times D] @ [B \times S \times D] \rightarrow [B \times S \times D]\)
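A minimal PyTorch sketch of steps (a)–(h). The class name, the rotate-half RoPE variant, and the use of `torch.nn.functional.scaled_dot_product_attention` (which applies the \(1/\sqrt{H}\) scaling and the causal mask internally) are assumptions for illustration, not prescribed by the text above:

```python
import torch
import torch.nn.functional as F  # note: shadows the FFN-dimension symbol F (unused here)

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotary position embeddings on x of shape [B, N, S, H]:
    channel i is rotated together with channel i + H/2 by a position-dependent angle."""
    _, _, S, H = x.shape
    half = H // 2
    pos = torch.arange(S, dtype=x.dtype, device=x.device)                                # [S]
    inv_freq = base ** (-torch.arange(half, dtype=x.dtype, device=x.device) / half)      # [H/2]
    angles = torch.outer(pos, inv_freq)                                                  # [S, H/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)                 # [B, N, S, H]

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, D: int, N: int):
        super().__init__()
        assert D % N == 0
        self.N, self.H = N, D // N
        # (a)-(c), (h): four [D x D] projections -> 4 * D * D parameters (no biases)
        self.W_q = torch.nn.Linear(D, D, bias=False)
        self.W_k = torch.nn.Linear(D, D, bias=False)
        self.W_v = torch.nn.Linear(D, D, bias=False)
        self.W_o = torch.nn.Linear(D, D, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        # (a)-(c) Q/K/V projections: [B, S, D] -> [B, S, D]
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        # (d) split heads: [B, S, D] -> [B, S, N, H] -> [B, N, S, H]
        q, k, v = [t.view(B, S, self.N, self.H).transpose(1, 2) for t in (q, k, v)]
        # (e) RoPE on Q and K only
        q, k = rope(q), rope(k)
        # (f) softmax(Q K^T / sqrt(H) + causal mask) @ V -> [B, N, S, H]
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (g) merge heads: [B, N, S, H] -> [B, S, D]
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        # (h) output projection: [B, S, D] -> [B, S, D]
        return self.W_o(out)

x = torch.randn(2, 16, 64)                  # [B, S, D]
assert MultiHeadAttention(D=64, N=8)(x).shape == x.shape
```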
RMS Normalization 2
\[\text{RMSNorm}_2: [B \times S \times D] \cdot [D] \rightarrow [B \times S \times D]\]
Feed-Forward Network (FFN)
- Gate projection: \(\text{Gate} = W_1 @ X, \quad [F \times D] @ [B \times S \times D] \rightarrow [B \times S \times F]\)
- Activation: \(\text{Gate} = \text{SiLU}(\text{Gate}) \quad \text{(element-wise activation)} \rightarrow [B \times S \times F]\)
- Linear projection: \(\text{Linear} = W_3 @ X, \quad [F \times D] @ [B \times S \times D] \rightarrow [B \times S \times F]\)
- Gated multiplication: \(\text{Gated} = \text{Gate} \otimes \text{Linear} \quad \text{(element-wise multiplication)}, \quad [B \times S \times F] \otimes [B \times S \times F] \rightarrow [B \times S \times F]\)
- Output projection: \(\text{Net} = W_2 @ \text{Gated}, \quad [D \times F] @ [B \times S \times F] \rightarrow [B \times S \times D]\)
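A minimal PyTorch sketch of this gated FFN; the class and attribute names are assumptions for illustration (note that `torch.nn.Linear(D, F)` stores its weight as \([F \times D]\), matching the convention above):

```python
import torch

class SwiGLUFFN(torch.nn.Module):
    def __init__(self, D: int, F: int):
        super().__init__()
        # 3 * D * F weight parameters in total (no biases)
        self.W_1 = torch.nn.Linear(D, F, bias=False)   # gate projection:   D -> F
        self.W_3 = torch.nn.Linear(D, F, bias=False)   # linear projection: D -> F
        self.W_2 = torch.nn.Linear(F, D, bias=False)   # output projection: F -> D

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(self.W_1(x))   # [B, S, F], element-wise SiLU
        linear = self.W_3(x)                           # [B, S, F]
        return self.W_2(gate * linear)                 # [B, S, F] -> [B, S, D]

x = torch.randn(2, 16, 64)                             # [B, S, D]
assert SwiGLUFFN(D=64, F=256)(x).shape == x.shape
```

Combining the shapes above into the counts the title promises, a hedged per-block tally (counting each multiply-add as two FLOPs and ignoring normalization, SiLU, softmax, and RoPE, which is a common approximation):

\[\text{Params per block} \approx \underbrace{4D^2}_{W_q, W_k, W_v, W_o} + \underbrace{3DF}_{W_1, W_2, W_3} + \underbrace{2D}_{\text{RMSNorm gains}}\]
\[\text{Matmul FLOPs per block} \approx \underbrace{2BS \cdot 4D^2}_{\text{attention projections}} + \underbrace{4BNS^2H}_{QK^T \text{ and } \text{attn-weights} @ V} + \underbrace{2BS \cdot 3DF}_{\text{FFN}}\]

with \(4BNS^2H = 4BS^2D\) since \(NH = D\); multiply by \(L\) for the full stack.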