Part 4 of How To Scale Your Model (Part 3: Sharding | Part 5: Training)
Here we'll do a quick review of the Transformer architecture, specifically how to calculate FLOPs, bytes, and other quantities of interest.
We'll start with vectors \(x\), \(y\) and matrices \(A\), \(B\) with the following shapes:
\[\def \red#1{\textcolor{red}{#1}} \def \green#1{\textcolor{green}{#1}} \def \blue#1{\textcolor{blue}{#1}} \def \purple#1{\textcolor{purple}{#1}} \def \orange#1{\textcolor{orange}{#1}} \def \gray#1{\textcolor{gray}{#1}} \begin{array}{cc} \textrm{array} & \textrm{shape} \\ \hline x & \textrm{[P]} \\ y & \textrm{[P]} \\ A & \textrm{[N P]} \\ B & \textrm{[P M]} \\ \hline \end{array}\]Note that for matrix-matrix multiplication, compute scales cubically (\(O(N^3)\)) while data transfer scales only quadratically (\(O(N^2)\)). This means that as we scale up our matmuls, it gets easier to reach the compute-saturated limit. This is extremely unusual, and explains in large part why we use architectures dominated by matrix multiplication: they lend themselves to being scaled!
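To make this concrete, here's a quick sketch (plain Python, assuming bf16 inputs) of how arithmetic intensity grows with matmul size:

```python
# FLOPs grow cubically while bytes moved grow only quadratically, so the
# arithmetic intensity (FLOPs per byte) of a square matmul grows linearly.
def matmul_intensity(n, bytes_per_el=2):   # bf16 = 2 bytes per element
    flops = 2 * n**3                       # [n, n] @ [n, n]
    bytes_moved = 3 * n**2 * bytes_per_el  # read A and B, write C
    return flops / bytes_moved             # = n / 3

for n in [256, 1024, 4096]:
    print(n, matmul_intensity(n))  # bigger matmuls saturate compute more easily
```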
During training, we don't particularly care about the result of a given matmul itself; we care about its derivatives. That means we perform significantly more FLOPs during backpropagation.

Imagine B is one matrix of a larger network, A is the input activations, and C = A B. The derivative of the loss L with respect to B is given by the chain rule:

\[\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial B} = A^T \left(\frac{\partial L}{\partial C}\right)\]which is an outer product and requires $2NPM$ FLOPs to compute (since it contracts over the $N$ dimension). Similarly, the derivative of the loss with respect to A is

\[\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial A} = \left(\frac{\partial L}{\partial C}\right) B^T\]which is again $2NPM$ FLOPs since dL/dC is a (co-)vector of size \([N, M]\). While this quantity is not itself a derivative with respect to the parameters, it is used to compute derivatives for earlier layers of the network (e.g. just as dL/dC was used to compute dL/dB above).

Adding these up, we find that training uses a total of 6NPM FLOPs, compared to 2NPM during inference: 2NPM in the forward pass and 4NPM in the backward pass. Since PM is the number of parameters in the matrix, this is the simplest form of the famous \(6 * \text{num parameters} * \text{num tokens}\) approximation of Transformer FLOPs during training: each token requires \(6 * \text{num parameters}\) FLOPs. We'll show a more correct derivation below.
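As a quick illustrative example (hypothetical model and dataset sizes), training a 70B-parameter dense model on 15T tokens would cost roughly

\[6 \cdot 70 \times 10^{9} \cdot 15 \times 10^{12} \approx 6.3 \times 10^{24} \ \textrm{FLOPs}.\]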
Transformers are the future. Well, they're the present at least. Maybe a few years ago, they were one of many architectures. But today, it's worth knowing pretty much every detail of the architecture. We won't reintroduce the architecture, but this blog and the original Transformer paper may be helpful references.
Here’s a basic diagram of the Transformer decoder architecture:
Note [gating einsum]: The diagram above uses "gating einsums", where the up-projection is split into two input matmuls whose outputs are combined element-wise as a kind of gating function (see the MLP table below).
Note 2 [MHA attention]: With self-attention, T and S are the same, but for cross-attention they may be different. With vanilla Multi-Head Attention (MHA), N and K are the same, while for Multi-Query Attention (MQA) K=1, and for Grouped-Query Attention (GQA) 1 < K < N, with each key-value head shared by a group of G = N / K query heads.
For the below we’re going to compute per-layer FLOPs to avoid having to stick factors of L everywhere.
The MLPs of a Transformer typically consist of 2 input matmuls that are element-wise combined and a single output matmul:
\[\begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \\ \hline \\ A[B,T,\red{D}] \cdot W_{in1}[\red{D}, F] & 6BTDF & DF \\[10pt] A[B,T,\red{D}] \cdot W_{in2}[\red{D}, F] & 6BTDF & DF \\[10pt] \sigma\left(A_{in1}\right)[B,T, F] * A_{in2}[B,T, F] & \gray{O(BTF)} \\[10pt] A[B,T,\red{F}] \cdot W_{out}[\red{F}, D] & 6BTDF & DF \\[10pt] \hline \\ & \approx 18BTDF & 3DF \end{array}\]
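To make the structure concrete, here's a minimal JAX sketch of this gated MLP block (assuming a SiLU nonlinearity for \(\sigma\); the actual choice varies by model family):

```python
import jax
import jax.numpy as jnp

def gated_mlp(x, w_in1, w_in2, w_out):
    # x: [B, T, D]; w_in1, w_in2: [D, F]; w_out: [F, D]
    a_in1 = jnp.einsum("btd,df->btf", x, w_in1)  # gating branch
    a_in2 = jnp.einsum("btd,df->btf", x, w_in2)  # value branch
    # element-wise combine, then project back down to D
    return jnp.einsum("btf,fd->btd", jax.nn.silu(a_in1) * a_in2, w_out)
```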
For the generic grouped-query attention case with different Q and KV head numbers, let us assume equal head dimension H for the Q, K, and V projections, and estimate the cost of the QKVO matmuls:

\[\begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \\ \hline \\ A[B,T,\red{D}] \cdot W_{Q}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{D}] \cdot W_{K}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{D}] \cdot W_{V}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{N}, \red{H}] \cdot W_{O}[\red{N}, \red{H}, D] & 6BTDNH & DNH \\[10pt] \hline \\ & 12BTD(N+K)H & 2D(N+K)H \end{array}\]The dot-product attention operation is more subtle, effectively being a \(TH \cdot HS\) matmul batched over the \(B\), \(K\) dimensions, a softmax, and a \(TS \cdot SH\) matmul again batched over the \(B\), \(K\) dimensions. We highlight the batched dims in blue:
\[\begin{array}{cc} \textrm{operation} & \textrm{train FLOPs} \\ \hline \\[3pt] Q[\blue{B}, T, \blue{K}, G, \red{H}] \cdot K[\blue{B}, S, \blue{K}, \red{H}] & 6BTSKGH = 6BTSNH \\[3pt] \textrm{softmax}_S \;\; L[B, T, S, K, G] & \gray{O(BTSKG) = O(BTSN)} \\[3pt] S[\blue{B}, T, \red{S}, \blue{K}, G] \cdot V[\blue{B}, \red{S}, \blue{K}, H] & 6BTSKGH = 6BTSNH \\[3pt] \hline \\ & \approx 12BTSNH = 12BT^2NH \\ \end{array}\]
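In einsum form, the whole dot-product attention operation is just two contractions batched over \(B\) and \(K\) with a softmax in between. A minimal sketch (omitting the causal mask and the \(1/\sqrt{H}\) scaling that a real implementation would include):

```python
import jax
import jax.numpy as jnp

def gqa_attention(q, k, v):
    # q: [B, T, K, G, H], k: [B, S, K, H], v: [B, S, K, H]
    logits = jnp.einsum("btkgh,bskh->btskg", q, k)    # batched over B and K
    probs = jax.nn.softmax(logits, axis=2)            # softmax over S
    return jnp.einsum("btskg,bskh->btkgh", probs, v)  # output: [B, T, K, G, H]
```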
There are several other operations happening in a Transformer. Layernorms are comparatively cheap and can be ignored for first-order cost estimates. There is also the final enormous (though not per-layer) unembedding matrix multiply.

\[\begin{array}{ccc} \textsf{operation} & \textsf{train FLOPs} & \textsf{params} \\ \hline \\ \textrm{layernorm}_D \;\; A[B,T,\red{D}] & \gray{O\left(BTD\right)} & \gray{D} \\[10pt] A[B,T,\red{D}] \cdot W_{unembed}[\red{D}, V] & 6BTDV & DV \\ \end{array}\]If we neglect the cost of dot-product attention for shorter-context training, then the total FLOPs across all layers is
\[\begin{align*} (18BTDF + 12BTD(N+K)H)L = 6 *BT * (3DF + 2D(N+K)H)L \\ = 6 * \textrm{num tokens} * \textrm{parameter count} \end{align*}\]Leading to a famous rule of thumb for estimating dense Transformer FLOP count, ignoring the attention FLOPs. (Unembedding is another simple matmul with $6BTDV$ FLOPs and $DV$ params, and follows the same rule of thumb.)
If we do account for dot-product attention above and assume \(F=4D\), \(D=NH\) (as is typical) and \(N=K\):
\[\small{\frac{\textrm{attention FLOPs}}{\textrm{matmul FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{4*18 BTD^2 + 24 BTD^2} = \frac{12BT^2D}{96 BTD^2} = \frac{T}{8D}}\]So the takeaway is that dot-product attention FLOPs only become dominant during training once $T > 8D$. For $D \sim 8\text{k}$, this would be ~64K tokens. This makes some sense: as the MLP size increases, the attention FLOPs become less critical. For large models, the quadratic cost of attention is thus not actually a huge obstacle to longer-context training. For smaller models, however (e.g. Gemma-27B, with $D=4608$), attention becomes dominant around 32k sequence lengths. Flash Attention also helps alleviate the cost of long context, which we discuss briefly in Appendix A.
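Here's the resulting ratio as a quick helper (a sketch under the same \(F=4D\), \(D=NH\), \(N=K\) assumptions):

```python
def attn_to_matmul_flops_ratio(T, D):
    # attention FLOPs / matmul FLOPs = T / 8D under the assumptions above
    return T / (8 * D)

print(attn_to_matmul_flops_ratio(65536, 8192))  # ~1.0: D ~ 8k crosses over near 64K tokens
print(attn_to_matmul_flops_ratio(36864, 4608))  # 1.0: Gemma-27B-sized crossover
```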
We’d be remiss not to briefly discuss Mixture of Experts (MoE) models
Compared to a dense model, an MoE introduces new comms, primarily two AllToAlls (one before and one after the MoE block) that route tokens to the correct expert and bring them back to their home device.
Backpropagation as an algorithm trades memory for compute. Instead of a backward pass requiring \(O(n_\text{layers}^2)\) FLOPs, it requires \(O(n_\text{layers})\) memory, saving all intermediate activations generated during the forward pass. While this is better than quadratic compute, it’s incredibly expensive memory-wise: a model with \(B * T=4M\) (4M total tokens per batch), L=64, and D=8192 that avoids all unnecessary backward pass compute would have to save roughly \(2 * 20 * B * T * D * L = 84TB\) of activations in bfloat16. 20 comes from (roughly) counting every intermediate node in the Transformer diagram above, since e.g.
\[f(x) = \exp(g(x))\] \[\frac{df}{dx} = \exp(g(x)) \cdot \frac{dg}{dx}\]so to avoid recomputing we need to save \(g(x)\) and \(\exp(g(x))\) from the forward pass. To avoid saving this much memory, we can choose to only save some fraction of the intermediate activations. Here are a few strategies we use.
This is by no means comprehensive. When using JAX, these are typically controlled by `jax.remat`/`jax.checkpoint` (you can read more here).
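For instance, here's a minimal sketch of block-level rematerialization with `jax.checkpoint` (a hypothetical block; real checkpointing policies are more fine-grained):

```python
import jax
import jax.numpy as jnp

# Wrapping a block in jax.checkpoint (an alias of jax.remat) tells JAX to
# save only the block's inputs and recompute its internals on the backward pass.
@jax.checkpoint
def mlp_block(x, w_in, w_out):
    return jnp.dot(jax.nn.relu(jnp.dot(x, w_in)), w_out)

def loss(x, w_in, w_out):
    return jnp.sum(mlp_block(x, w_in, w_out) ** 2)

# Gradients work as usual; only the activation-saving policy changes.
grad_fn = jax.grad(loss, argnums=(1, 2))
```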
As we’ll see in Section 7, LLM inference has two key parts, prefill and generation.
Each KV cache is then effectively an array of size $[2, S, L, K, H]$ where the 2 accounts for the keys and values. This is quite large! The total size of the key-value cache in int8 is $2SLKH$ bytes. For a moderately-sized model with 8k context length, 64 layers, and $KH = NH = D = 8192$, this is $2 \cdot 8192 \cdot 64 \cdot 8192 = 8\text{GiB}$. You can see why we would want to use GQA or MQA with $K \ll N$.
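Here's that arithmetic as a tiny helper (int8 KVs, so one byte per element):

```python
def kv_cache_bytes(S, L, K, H, bytes_per_el=1):
    # 2 accounts for the keys and the values
    return 2 * S * L * K * H * bytes_per_el

# The example from the text: S=8192, L=64, K*H = 8192 (e.g. K=64 heads of H=128)
print(kv_cache_bytes(8192, 64, 64, 128) / 2**30, "GiB")  # -> 8.0
```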
| Component | Params per layer | Training FLOPs per layer |
| --- | --- | --- |
| MLP | $3DF$ | $18BTDF$ |
| Attention | $4DNH$ | $24BTDNH + 12BT^2NH$ |
| Other | $D$ | $BTD$ |
| Vocab | $DV$ (total, not per-layer) | $12BTDV$ |
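As a sanity check (and as a helper for Question 1 below), here's the table as code, using the generic $2D(N+K)H$ attention term from above:

```python
def transformer_params(D, F, N, K, H, L, V):
    mlp = 3 * D * F             # per layer
    attn = 2 * D * (N + K) * H  # per layer; equals 4*D*N*H when N == K
    other = D                   # per layer
    return L * (mlp + attn + other) + D * V  # vocab is total, not per-layer

print(transformer_params(4096, 4 * 4096, 32, 32, 128, 64, 32_000))  # ~17.3B
```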
Question 1: How many parameters does a model with $D=4096$, $F=4 \cdot D$, $V=32,000$, and $L=64$ have? What fraction of these are attention parameters? How large are our KV caches per token? You can assume $N\cdot H=D$ and multi-head attention with int8 KVs.
The MLP parameters per layer are $3DF = 12D^2 \approx 201\text{M}$ and the attention parameters per layer are $4DNH = 4D^2 \approx 67\text{M}$, so with the vocab matrix $DV \approx 131\text{M}$ the total is roughly $64 \cdot (201\text{M} + 67\text{M}) + 131\text{M} \approx 17.3\text{B}$ parameters, of which about $4D^2 / 16D^2 = 25\%$ are attention parameters. The KV cache per token is $2 \cdot L \cdot N \cdot H = 2 \cdot 64 \cdot 4096$ bytes in int8, i.e. 512kB / token.
Question 2: How many total FLOPs are required to perform $A[B_X, D_Y] \cdot_D W[D_Y, F]$ on the mesh `{'X': 4, 'Y': 8, 'Z': 4}`? How many FLOPs are performed by each TPU?
The total "theoretical" FLOP count of the operation is \(2 \cdot B \cdot D \cdot F\). However, because the computation isn't sharded across the Z dimension, each of the Z slices duplicates the work, so we're actually doing \(2 \cdot B \cdot D \cdot F \cdot Z\) total FLOPs. Since the computation is sharded across the other dimensions, each device performs roughly \(2 \cdot B \cdot D \cdot F / (X \cdot Y)\) of these.
Question 3: How many FLOPs are involved in performing $A[I,J,K,L] * B[I,J,M,N,O] \rightarrow C[K,L,M,N,O]$?
Following the rule above, we have I and J as contracting dimensions and K, L, M, N, and O as non-contracting dimensions. We have no "batching dimensions", so this is just \(2 \cdot I \cdot J \cdot K \cdot L \cdot M \cdot N \cdot O\), the product of all the axes. If we had a shared batch axis, it would only be counted once.
Question 4: What is the arithmetic intensity of self-attention (ignoring the Q/K/V/O projections)? Give the answer as a function of the Q and KV lengths T and S. At what context length is attention FLOPs-bound? Given the HBM bandwidth of our TPUs, plot the effective relative cost of attention to the FFW block as the context length grows.
Self-attention requires loading the \(Q\), \(K\), and \(V\) activations, then computing \(\text{softmax}(Q \cdot K) \cdot V\), then writing the result back to HBM. This will be done with Flash Attention so there are some caveats to this math, but basically in bf16 self-attention performs
\[\text{Q[B,T,N,H]} \rightarrow_\text{reshape} \text{Q[B, T, K, G, H]} \cdot \text{K[B, S, K, H]} \rightarrow \text{O[B, T, S, K, G]}\] \[U=\text{softmax}_S(\text{O[B, T, S, K, G]})\] \[\text{U[B, T, S, K, G]} \cdot \text{V[B, S, K, H]} \rightarrow \text{X[B, T, K, G, H]}\]So our total bytes is \(2 * \text{sizeof}(Q) + 2 * \text{sizeof(K or V)} = 4BTNH + 4BSKH = 4BHK * (TG + S)\), total FLOPs is \(4BTSNH + O(BTSN)\) and the arithmetic intensity is \(4BTSKGH / (4BHK * (TG + S))\).
So basically, during prefill we have \(S=T\), giving an arithmetic intensity of \(4BT^2KGH / (4BHKT(G+1)) = TG/(G + 1) = O(T)\). During generation, \(T=1\) so we have \(4BSKGH / (4BHK \cdot (G + S)) = SG / (G + S) \rightarrow G\) assuming \(S\) is very large. Depending on how you interpret the question, during prefill or training self-attention becomes compute bound around S=240 assuming no sequence sharding. During generation, we are never compute bound because \(G\) is small. Nonetheless, you can see that increasing \(G\) brings us closer to being compute bound.
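The intensity formula is easy to play with directly (a sketch; compute-bound above roughly 240 on our TPUs):

```python
def attn_intensity(T, S, G):
    # FLOPs / bytes = T*S*G / (T*G + S), from the derivation above
    return (T * S * G) / (T * G + S)

print(attn_intensity(8192, 8192, 1))  # prefill, MHA: T*G/(G+1) = T/2 = 4096
print(attn_intensity(1, 8192, 8))     # generation: SG/(G+S) -> G ~= 8
```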
Question 5: At what sequence length are self-attention FLOPs equal to the QKVO projection FLOPs?
This is purely a question of when \(24BTDNH == 12BT^2NH\). Simplifying we get \(2D = T\), so e.g. for \(D=4096\), this is \(8192\). This tells us that for most reasonable context lengths, matmul FLOPs are greater.
Question 6: Say we only save the output of each of the 7 main matmuls in a Transformer layer during our forward pass (Q, K, V, O + the three FFW matrices). How many extra FLOPs do we need to “rematerialize” during the backwards pass?
Question 7: DeepSeek v3 says it was trained for 2.79M H800 hours on 14.8T tokens (source). Given that it has 37B activated parameters, roughly what hardware utilization did they achieve? Hint: note that they used FP8 FLOPs without structured sparsity.
From the spec sheet here, we find 3,026 TFLOPs/s of FP8 performance with sparsity, or typically half this (`1.513e15` FLOPs/s) without sparsity. 2.79M H800 hours means `2.79e6 * 1.513e15 * 60 * 60 = 1.52e25` total FLOPs. Given the activated parameter count of 37B, this training run should have used about `6 * 37e9 * 14.8e12 = 3.3e24` FLOPs. That means the FLOPs utilization is about `3.3e24 / 1.52e25 = 21.7%`.
Question 8: Mixture of Experts (MoE) models have $E$ copies of a standard dense MLP block, and each token activates $k$ of these experts. What batch size in tokens is required to be compute-bound for an MoE with weights in int8 on TPU v5e? For DeepSeek, which has 256 (routed) experts and $k=8$, what is this number?
Because we have $E$ copies of each expert, in int8, we need to load $E \cdot D \cdot F$ bytes. Because each token activates $k$ experts, we have $2\cdot k \cdot B \cdot D \cdot F$ FLOPs. To be compute-bound with bfloat16 FLOPs, we need an arithmetic intensity over 240 which happens when $(2\cdot k \cdot BDF) / EDF > 240$ or $k \cdot B / E > 120$.
Therefore, we need $B > 120 \cdot E / k$ to be compute bound. For DeepSeek, this gives us $B > 120 \cdot 256 / 8 = 3840$. This is a remarkably large batch size at generation time.
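The same arithmetic as a one-liner (int8 weights, bf16 FLOPs at a ~240 roofline):

```python
def moe_min_batch_tokens(E, k, roofline=240):
    # need 2*k*B*D*F / (E*D*F) > roofline, i.e. B > roofline * E / (2k)
    return roofline * E / (2 * k)

print(moe_min_batch_tokens(256, 8))  # DeepSeek: 3840.0 tokens
```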
The traditional objection to scaling Transformers to very long context is that the attention FLOPs and memory usage scale quadratically with context length. While it's true that the attention QK product has shape $[B, T, S, N]$ where B is the batch size, T and S are the Q and K sequence dims, and N is the number of heads, this claim comes with some serious caveats:

1. As we computed above, attention FLOPs only come to dominate the matmul FLOPs once $T > 8D$, which for large models is already a fairly long context.
2. We never need to materialize the full attention matrix in order to compute attention; we can instead compute it in local chunks, keeping memory far below quadratic.
This second observation was first made by Rabe et al. 2021 and later in the Flash Attention paper (Dao et al. 2022). The basic idea is to compute the attention in chunks of K/V: for each chunk we compute the local softmax along with some auxiliary statistics (the chunk-local max and the running sum of exponentials), then pass them on to the next chunk, which combines them with its own local results.
With these, we can compute the new max, the new running sum, and the new output with only a constant amount of memory. To give a sketchy description of how this works, attention is roughly this operation:
\[\text{Attn}(Q, K, V) = \sum_i \frac{\exp(Q \cdot K_i - \max_j Q \cdot K_j) V_i}{\sum_l \exp(Q \cdot K_l - \max_j Q \cdot K_j)}\]The max is subtracted for numerical stability and can be added without affecting the outcome since \(\sum_i \exp(a_i + b) = \exp(b) \sum \exp(a)\). Looking just at the denominator above, if we imagine having two contiguous chunks of key vectors, \(K^1\) and \(K^2\) and we compute the local softmax sums \(L^1\) and \(L^2\) for each
\[L^1 = \sum_i \exp(Q \cdot K_i^1 - \max_j Q \cdot K_j^1)\] \[L^2 = \sum_i \exp(Q \cdot K_i^2 - \max_j Q \cdot K_j^2)\]Then we can combine these into the full softmax sum for these two chunks together by using
\[L^\text{combined} = \exp(M^1 - \max(M^1, M^2)) \cdot L^1 + \exp(M^2 - \max(M^1, M^2)) \cdot L^2\]where
\[M^1 = \max_j Q \cdot K_j^1 \text{ and } M^2 = \max_j Q \cdot K_j^2\]This can be done for the full softmax as well, giving us a way of accumulating arbitrarily large softmax sums. Here’s the full algorithm from the Flash Attention paper.
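To make the accumulation concrete, here's a minimal single-query sketch of this running-softmax trick (not the actual tiled kernel from the paper, which also chunks Q and keeps everything in on-chip memory):

```python
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk=128):
    # q: [H], k: [S, H], v: [S, H]; equivalent to softmax(q @ k.T) @ v
    m = -jnp.inf           # running max of the logits
    l = 0.0                # running sum of exp(logit - m)
    o = jnp.zeros_like(q)  # running (unnormalized) output
    for i in range(0, k.shape[0], chunk):
        s = q @ k[i:i + chunk].T      # chunk-local logits
        m_new = jnp.maximum(m, s.max())
        alpha = jnp.exp(m - m_new)    # rescale old stats to the new max
        p = jnp.exp(s - m_new)
        l = alpha * l + p.sum()
        o = alpha * o + p @ v[i:i + chunk]
        m = m_new
    return o / l
```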
From a hardware standpoint, this lets us fit our chunk of Q into VMEM (what the algorithm above calls on-chip SRAM) so we only have to load the KV chunks on each iteration, reducing the arithmetic intensity. We can also keep the running statistics in VMEM.
One last subtle point worth emphasizing is an attention softmax property that’s used to make the Flash VJP (reverse mode derivative) calculation practical for training. If we define an intermediate softmax array as:
\[S_{ij} = \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}}\]In attention, we obtain dS from the reverse-mode dO and V arrays:

\[dS_{ij} = dO_{id} \cdot_d V_{jd} = \sum_d dO_{id} V_{jd}\]During the backpropagation of this gradient to Q and K:
\[d(q_i \cdot k_j) = (dS_{ij} - S_{ij} \cdot_j dS_{ij}) S_{ij}\]We exploit an identity that allows us to exchange a contraction along the large key length dimension with a local contraction along the feature depth dimension.
\[\begin{align*} S_{ij} \cdot_j dS_{ij} &= \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} \sum_d dO_{id} V_{jd} \\ &= \sum_d dO_{id} \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} V_{jd} \\ &= \sum_d dO_{id} O_{id} \\ &= dO_{id} \cdot_d O_{id} \end{align*}\]This replacement is crucial for being able to implement a sequence-block local calculation for the VJP, and enables further clever sharding schemes like ring attention.
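As a quick numerical check of this identity (hypothetical small shapes, \(\tau = 1\)):

```python
import jax
import jax.numpy as jnp

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))   # [T, H]
k = jax.random.normal(jax.random.PRNGKey(1), (6, 8))   # [S, H]
v = jax.random.normal(jax.random.PRNGKey(2), (6, 8))   # [S, H]
dO = jax.random.normal(jax.random.PRNGKey(3), (4, 8))  # upstream gradient [T, H]

S = jax.nn.softmax(q @ k.T, axis=-1)  # [T, S]
O = S @ v                             # [T, H]
dS = dO @ v.T                         # [T, S]

lhs = jnp.sum(S * dS, axis=-1)  # contraction over the (large) key dimension
rhs = jnp.sum(dO * O, axis=-1)  # contraction over the (small) feature dimension
print(jnp.allclose(lhs, rhs, atol=1e-5))  # True
```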