Part 5 of How To Scale Your Model (Part 4: Transformers | Part 6: Training LLaMA)
여기서는 LLM 훈련 중에 사용되는 4가지 주요 병렬 처리 방식인 데이터 병렬 처리(data parallelism), 완전 샤딩된 데이터 병렬 처리(FSDP), 텐서 병렬 처리(tensor parallelism), 그리고 파이프라인 병렬 처리(pipeline parallelism)에 대해 논의합니다. 각각에 대해 어느 시점에서 통신에 의해 병목 현상이 발생하는지 계산합니다.
번역 안내: 원저자(Jacob Austin)의 허락을 받아 원문을 번역 중입니다.
해당 글의 1인칭은 원문 저자를 지칭합니다.
원문: How to Scale Your Model
번역: 신종훈
“모델 스케일링”의 목표는 훈련이나 추론에 사용되는 칩의 수를 늘리면서 처리량을 비례적으로, 즉 선형적으로 증가시키는 것입니다(이를 strong scaling이라고 합니다). 단일 칩에서의 성능은 메모리 대역폭과 FLOPs 간의 트레이드오프에 달려있지만, 클러스터 수준에서의 성능은 칩 간 통신을 유용한 FLOPS와 중첩시켜 숨기는 것에 달려있습니다. 이는 간단하지 않은데, 칩 수를 늘리면 통신 부하가 늘어나는 동시에 이를 숨기는 데 사용할 수 있는 디바이스당 연산량은 줄어들기 때문입니다. 섹션 3에서 보았듯이, 샤딩된 행렬 곱셈은 종종 비싼 AllGather나 ReduceScatter를 필요로 하며, 이는 TPU가 유용한 작업을 하는 것을 막을 수 있습니다. 이 섹션의 목표는 이것들이 언제 너무 비싸지는지 알아내는 것입니다.
이 섹션에서는 네 가지 일반적인 병렬 처리 방식에 대해 논의합니다: (순수) 데이터 병렬 처리(data parallelism), 완전 샤딩된 데이터 병렬 처리(FSDP / ZeRO sharding), 텐서 병렬 처리(tensor parallelism)(모델 병렬 처리라고도 함), 그리고 (간략하게) 파이프라인 병렬 처리(pipeline parallelism)입니다. 각각에 대해 어떤 통신 비용이 발생하고 어느 시점에서 그 비용이 연산 비용에 병목이 되기 시작하는지 보여줄 것입니다.
이 섹션 전체에서 계산을 단순화하기 위해 다음 표기법을 사용할 것입니다.
| Notation | Meaning (model parameters) |
|---|---|
| D | dmodel ( hidden dimension/residual stream dim) |
| F | dff (feed-forward dimension) |
| B | Batch dimension (배치의 토큰 수; 디바이스별이 아닌 전체) |
| T | Sequence length |
| L | Number of layers in the model |
| Notation | Meaning (hardware characteristic) |
|---|---|
| C | FLOPS/s per chip |
| W | Network bandwidth (bidirectional, often subscripted as e.g. $W_{\text{ici}}$ or $W_{\text{dcn}}$ |
| X | Number of chips along mesh axis X |
| Y | Number of chips along an alternate mesh axis, labeled Y |
| Z | Number of chips along a third mesh axis, labeled Z |
단순화를 위해, Transformer를 MLP 블록의 스택으로 근사하겠습니다. 섹션 4에서 보았듯이 어텐션은 대규모 모델에서 FLOPs의 비교적 적은 부분을 차지하기 때문입니다. 또한 게이팅 matmul을 무시하고 각 레이어에 대해 다음과 같은 간단한 구조만 남겨두겠습니다:
bf16[D, F] (up-projection) 과 Wout: bf16[F, D] (down-projection), 그리고 입력 In: bf16[B, D] 의 스택으로 취급합니다.Forward pass: compute Loss[B]
Backward pass: compute dWout[F, D], dWin[D, F]
통신이 추가된 알고리즘과 비교하기 위해 이를 제공합니다.
다음은 우리가 논의할 4가지 병렬 처리 방식입니다. 각 방식은 위 다이어그램의 In, Win, Wout, Out에 대한 샤딩에 의해 고유하게 정의되는 것으로 생각할 수 있습니다.
1. Data parallelism: 활성화는 배치에 따라 샤딩되고, 파라미터와 옵티마이저 상태는 각 디바이스에 복제됩니다. 통신은 역전파 중에만 발생합니다.
\[\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]\]2. Fully-sharded data parallelism (FSDP or ZeRO-3): 활성화는 배치에 따라 샤딩되며(순수 데이터 병렬 처리처럼), 파라미터는 동일한 메시 축을 따라 샤딩되고 순방향 패스에서 사용되기 직전에 AllGather됩니다. 옵티마이저 상태도 배치에 따라 샤딩됩니다. 중복 메모리를 줄입니다.
\[\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]\]3. Tensor parallelism (Megatron sharding 또는 model parallelism이라고도 함): 활성화는 D ($d_\text{model}$)에 따라 샤딩되고, 파라미터는 F ($d_{ff}$)에 따라 샤딩됩니다. 각 블록 전후에 활성화를 AllGather하고 ReduceScatter합니다. FSDP와 호환됩니다.
\[\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]\]4. Pipeline parallelism: 가중치는 레이어 차원에 따라 샤딩되고, 활성화는 마이크로배치되어 레이어 차원에 따라 굴러갑니다(rolled). 파이프라인 단계 간 통신은 최소화됩니다(단일 홉으로 활성화만 이동). 표기법을 남용하자면:
\(\text{In}[L_Z, B, D][i] \cdot_D W_\text{in}[L_Z, D, F][i] \cdot_F W_\text{out}[L_Z, F, D][i] \rightarrow \text{Out}[L_Z, B, D][i]\) \(\text{In}[L_Z, B, D][i] \cdot_D W_\text{in}[L_Z, D, F][i] \cdot_F W_\text{out}[L_Z, F, D][i] \rightarrow \text{Out}[L_Z, B, D][i]\)
Syntax: \(\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]\)
모델이 아주 작은 배치 크기(>240 토큰, compute-bound가 되도록)라도 단일 칩에 맞는다면, 항상 단순한 데이터 병렬 처리를 사용해야 합니다. 순수 데이터 병렬 처리는 TPU 수가 배치 크기보다 작은 한 활성화를 원하는 수의 TPU에 걸쳐 분할합니다. 순방향 패스에는 통신이 포함되지 않지만, 모든 단계의 끝에서 각 TPU는 파라미터를 업데이트하기 전에 로컬 그라디언트를 동기화하기 위해 AllReduce를 수행합니다.
Pure Data Parallelism Algorithm:
Forward pass: compute Loss[BX]
Backward pass: compute dWout[F, D], dWin[D, F]
손실 함수의 세부 사항은 무시하고 $\text{Tmp} = W_\text{in} \cdot \text{In}$으로 축약합니다. 최종 손실은 평균 AllReduce(Loss[BX])이지만, 가중치 그라디언트를 평균화할 때 역방향 패스에서만 AllReduce를 계산하면 됩니다.
순방향 패스에는 통신이 없다는 점에 유의하세요 — 모두 역방향 패스에 있습니다! 역방향 패스는 또한 AllReduce가 “critical path”에 있지 않다는 훌륭한 속성을 가지고 있습니다. 즉, 각 AllReduce는 편리할 때 수행할 수 있으며 후속 작업을 수행하는 것을 차단하지 않습니다. 총 통신 비용이 총 연산 비용을 초과하면 전체 통신 비용이 여전히 병목이 될 수 있지만, 구현 관점에서는 훨씬 더 관대합니다. 모델/텐서 병렬 처리에는 이 속성이 없다는 것을 보게 될 것입니다.
왜 이렇게 할까요? 순수 데이터 병렬 처리는 배치 차원에 걸쳐 활성화를 분할하여 활성화 메모리 압박을 줄여줍니다. 배치 차원을 분할할 칩이 더 많다면 거의 임의로 배치 크기를 늘릴 수 있습니다. 특히 훈련 중에 활성화가 메모리 사용량을 지배하는 경우가 많으므로 이는 매우 유용합니다.
왜 이렇게 하지 않을까요? 순수 데이터 병렬 처리는 모델 파라미터나 옵티마이저 상태로 인한 메모리 압박을 줄이는 데 아무런 도움이 되지 않습니다. 즉, 파라미터 + 옵티마이저 상태가 단일 TPU에 맞지 않는 대규모의 흥미로운 모델에는 거의 유용하지 않습니다. 규모 감각을 주기 위해, 파라미터를 bf16으로, 옵티마이저 상태를 fp32로 Adam
Takeaway: Adam과 순수 데이터 병렬 처리로 훈련할 수 있는 가장 큰 모델은 \(\text{num_params} = \text{HBM per device} / 10\)입니다. TPU v5p의 경우 대략 9B 파라미터입니다.
실제 모델 훈련에 유용하게 사용하려면 모델 파라미터나 옵티마이저를 적어도 부분적으로 샤딩해야 합니다.
언제 통신에 의해 병목이 발생할까요? 위에서 볼 수 있듯이 레이어당 두 개의 AllReduce가 있으며, 각각의 크기는 \(2DF\) (bf16 가중치용)입니다. 데이터 병렬 처리는 언제 통신 병목이 될까요?
위의 표에서와 같이, $C$ = 칩당 FLOPs, $W_{\text{ici}}$ = 양방향(bidirectional) 네트워크 대역폭, $X$ = 배치가 분할된 샤드 수
Communication time: 이전 섹션에서 1D 메시에서 AllReduce를 수행하는 데 필요한 시간은 AllReduce되는 배열의 총 바이트 수와 ICI 대역폭 $W_\text{ici}$에만 의존한다는 것을 알고 있습니다. 구체적으로 AllReduce 시간은 $2 \cdot \text{total bytes} / W_\text{ici}$입니다. $W_\text{in}$과 $W_\text{out}$ 모두에 대해 AllReduce가 필요하므로 레이어당 2개의 AllReduce가 있습니다. 각 AllReduce는 가중치 행렬, 즉 $DF$ 파라미터 배열 또는 $2DF$ 바이트에 대한 것입니다. 이를 모두 종합하면 단일 레이어의 AllReduce 총 시간은 다음과 같습니다.
\[\begin{align} T_\text{comms} &= \frac{2 \cdot 2 \cdot 2 \cdot D \cdot F}{W_\text{ici}}. \\ \end{align}\]Matmul time: 각 레이어는 순방향 패스에서 두 개의 matmul, 또는 역방향 패스에서 네 개의 matmul로 구성되며, 각각은 $2(B/X)DF$ FLOPs를 필요로 합니다. 따라서 역방향 패스의 단일 레이어에 대해 다음을 얻습니다.
\[\begin{align} T_\text{math} &= \frac{2 \cdot 2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ \end{align}\]중첩하므로, 레이어당 총 시간은 이 두 수량의 최대값입니다:
\[\begin{aligned} T &\approx \max(\frac{8 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{8 \cdot D \cdot F}{W_\text{ici}}) \\ T &\approx 8 \cdot D \cdot F \cdot \max(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}) \end{aligned}\]\(T_\text{math}/T_\text{comms} > 1\)일 때, 또는 다음과 같을 때 compute-bound가 됩니다.
\[\begin{align} \frac{B}{X} > \frac{C}{W_\text{ici}}. \end{align}\]결론은 데이터 병렬 처리로 compute-bound를 유지하려면 디바이스당 배치 크기 \(B / X\)가 ICI 연산 강도(operational intensity) $C / W_\text{ici}$를 초과해야 한다는 것입니다. 이는 궁극적으로 계산 시간은 디바이스당 배치 크기에 비례하는 반면, 통신 시간은 이 수량과 무관하다는 사실(모델 가중치를 전송하고 있기 때문)의 결과입니다. $B > C/W_\text{ici}$ 조건이 단일 디바이스 compute-bound 규칙 $B > 240$과 유사하다는 점에 주목하세요. 그 경우에도 규칙은 계산 시간이 배치 크기에 비례하는 반면 데이터 전송 크기는 ($B \ll F, D$ 범위에서) 배치 크기와 무관하다는 사실에서 비롯되었습니다.
규모 감각을 얻기 위해 실제 숫자를 넣어 봅시다. TPUv5p의 경우 ICI를 통한 1D 데이터 병렬 처리에 대해 C=4.6e14이고 W=2 * 9e10이므로, 통신 병목 현상을 피하려면 칩당 배치 크기가 최소 2,550이어야 합니다. 여러 축에 대해 데이터 병렬 처리를 수행할 수 있으므로, TPUv5p pod의 세 축 모두를 순수 데이터 병렬 처리에 할애하면 대역폭 $W_\text{ici}$를 3배로 늘릴 수 있고 TPU당 BS=850 또는 pod(8960 칩)당 배치 7.6M 토큰까지 줄일 수 있습니다! 이는 순수 데이터 병렬 처리에 의해 병목 현상이 발생하기가 꽤 어렵다는 것을 말해줍니다!
Note [context parallelism]: 이 섹션 전체에서 $B$는 항상 토큰 단위의 총 배치 크기를 나타냅니다. 하지만 분명히 배치는 많은 다른 시퀀스로 구성되어 있는데, 이는 어떻게 작동할까요? MLP에 관한 한, 토큰은 토큰입니다! 같은 시퀀스에 속하든 다른 두 시퀀스에 속하든 상관없습니다. 따라서 우리는 배치 차원과 시퀀스 차원 모두에 대해 데이터 병렬 처리를 수행할 수 있습니다. 이를 컨텍스트 병렬 처리(context parallelism) 또는 시퀀스 병렬 처리(sequence parallelism)라고 부르지만, 단순히 또 다른 종류의 데이터 병렬 처리라고 생각할 수 있습니다. 어텐션은 시퀀스 간 계산을 수행하므로 MLP보다 까다롭지만, 이는 어텐션 중에 KV 또는 Q를 gather하고 FLOPs와 통신을 신중하게 중첩시켜(일반적으로 “ring attention”이라는 것을 사용) 처리할 수 있습니다. 이 섹션 전체에서 우리는 시퀀스 차원을 완전히 무시하고 일정량의 배치 또는 시퀀스 병렬 처리를 가정할 것입니다.
Note on multiple mesh axes: 여러 축이 사용 가능한 대역폭에 어떤 영향을 미치는지 빠르게 짚고 넘어가야 합니다. 주어진 병렬 처리 전략에 여러 메시 축을 사용할 때 더 많은 대역폭을 얻습니다.
Syntax: \(\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]\)
완전 샤딩된 데이터 병렬 처리(종종 FSDP 또는 ZeRO-sharding
AllReduce가 AllGather와 ReduceScatter로 분해될 수 있다는 것을 기억하실 것입니다(섹션 3에서). 즉, 표준 데이터 병렬 처리를 위해 전체 그라디언트 AllReduce를 수행하는 대신, 칩에 가중치와 옵티마이저 상태를 샤딩하고, 순방향 패스 동안 각 레이어에서 AllGather하고, 역방향 패스 동안 가중치에 대해 추가 비용 없이 ReduceScatter할 수 있습니다.
Fully-Sharded Data Parallelism (FSDP):
Forward pass: compute Loss[BX]
Backward pass: compute dWout[F, DX], dWin[DX, F]
이것은 불필요한 연산을 수행하거나 불필요한 상태를 저장하지 않기 때문에 “ZeRo Overhead sharding”에서 유래하여 “ZeRO Sharding”이라고도 합니다. ZeRO-{1,2,3}은 각각 옵티마이저 상태, 그라디언트, 가중치를 이런 방식으로 샤딩하는 것을 지칭하는 데 사용됩니다. 모두 동일한 통신 비용을 가지므로
왜 이렇게 할까요? 표준 데이터 병렬 처리는 많은 중복 작업을 포함합니다. 각 TPU는 전체 그라디언트를 AllReduce한 다음 전체 옵티마이저 상태를 업데이트하고(모든 TPU에서 동일한 작업), 파라미터를 업데이트합니다(다시 완전히 중복됨). ZeRO 샤딩(그라디언트/옵티마이저 상태 샤딩)의 경우, AllReduce 대신 그라디언트를 ReduceScatter하고, 옵티마이저 상태의 샤드만 업데이트하고, 파라미터의 샤드를 업데이트한 다음, 순방향 패스에 필요할 때 파라미터를 AllGather할 수 있습니다.
언제 통신에 의해 병목이 발생할까요? 역방향 패스의 각 AllReduce가 AllGather + ReduceScatter가 되었기 때문에 우리의 상대적인 FLOPs 및 통신 비용은 순수 데이터 병렬 처리와 정확히 동일합니다. AllReduce는 각각 절반의 비용을 가진 AllGather와 ReduceScatter로 구현된다는 것을 상기하세요. 여기서는 역방향 패스와 동일한 FLOPs 대 통신 비율을 가지므로 순방향 패스를 모델링합니다:
\[\begin{aligned} T_\text{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ T_\text{comms} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\ T_\text{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ T_\text{comms} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\ T &\approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{4 \cdot D \cdot F}{W_\text{ici}}\right) \\ T &\approx 4 \cdot D \cdot F \cdot \max\left(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}\right) \end{aligned}\]따라서 순수 데이터 병렬 처리와 마찬가지로, \(B / X > C / W_\text{ici}\)일 때, 즉 디바이스당 배치 크기 $B/X$가 “ICI 연산 강도” $C/W_\text{ici}$ (v5p의 경우 4.59e14 / 1.8e11 = 2550)를 초과할 때 compute bound입니다. 이는 우리에게 아주 좋은 일입니다. 왜냐하면 순수 데이터 병렬 처리에 대해 compute-bound일 만큼 디바이스당 배치 크기가 크다면, compute-bound 영역을 벗어나는 것에 대해 걱정하지 않고 단순히 FSDP로 업그레이드하여 엄청난 양의 파라미터 및 옵티마이저 상태 메모리를 절약할 수 있다는 것을 의미하기 때문입니다! 순방향 패스에 통신을 추가해야 했지만, 이 비용은 순방향 패스 FLOPs와 겹치기 때문에 중요하지 않습니다.
Takeaway: FSDP와 순수 데이터 병렬 처리는 모두 디바이스당 배치 크기가 $2550 / M_X$ 미만일 때 TPUv5에서 대역폭 병목이 발생합니다. 여기서 $M_X$는 메시 축의 수입니다.
예를 들어, DeepSeek-V2(훈련 배치 크기에 대한 정보를 공개한 유일한 최신 강력한 모델 중 하나)는 ~40M 토큰의 배치 크기를 사용했습니다. 이를 통해 대역폭 한계에 도달하기 전에 대략 47,000개의 칩, 즉 약 5개의 TPUv5 pod로 확장할 수 있습니다.
LLaMA-3 70B의 경우, 약 6.3e24 (15e12 * 70e9 * 6) FLOPs로 훈련되었으며, 16M 토큰 배치를 대략 16e6 / (2550 / 3) = 18,823개의 칩(대략 8960개 칩의 2개 pod)으로 나눌 수 있으며, 각 칩은 50% 피크 FLOPs 활용률(종종 MFU라고 함)에서 4.59e14 FLOPs를 실행하여, 약 17일 만에 훈련할 수 있습니다. 나쁘지 않습니다! 하지만 더 잘할 수 있는 방법을 알아봅시다.
Note on critical batch size: 다소 직관적이지 않게도, 총 배치 크기가 줄어들수록(고정된 칩 수에서) 통신 병목 현상이 더 많이 발생합니다. 데이터 병렬 처리와 FSDP를 사용하면 배치 크기를 계속 늘릴 수만 있다면 임의로 많은 칩으로 확장할 수 있습니다! 그러나 실제로 배치 크기가 증가함에 따라 그라디언트가 거의 노이즈가 없어지기 때문에 훈련에서 수확 체감을 보는 경향이 있습니다. 또한 때때로 훈련 불안정을 봅니다. 따라서 “무제한 컴퓨팅 체제”에서 최적의 샤딩 방식을 찾는 게임은 종종 스케일링 법칙에 의해 결정된 고정된 배치 크기와 알려진(큰) 칩 수에서 시작하여, 그 작은 배치 크기를 그렇게 많은 칩에 맞출 수 있는 파티셔닝을 찾는 것을 목표로 합니다.
Syntax: \(\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]\)(결국 FSDP와 결합하기 위해\(Y\)를 사용합니다)
완전 샤딩된 데이터 병렬 AllReduce에서는 가중치를 칩 간에 이동시킵니다. 모델의 피드포워드 차원을 샤딩하고 레이어 동안 활성화를 이동시킬 수도 있습니다. 이를 “1D 모델 병렬 처리” 또는 Megatron 샤딩
언급했듯이, In[B, DY] *D Win[D, FY] *F Wout[FY, D] -> Out[B, DY] 는 첫 번째 matmul 전에 활성화를 gather해야 함을 의미합니다. 활성화가 가중치보다 작을 때 이것은 ZeRO 샤딩보다 저렴합니다. 이는 일반적으로 일정량의 ZeRO 샤딩이 추가된 경우에만 해당됩니다(gather의 크기를 줄여줌). 이것이 우리가 ZeRO 샤딩과 텐서 병렬 처리를 혼합하는 경향이 있는 이유 중 하나입니다.
Tensor Parallelism:
Forward pass: compute Loss[B]
Backward pass: compute dWout[FY, D], dWin[D, FY]
텐서 병렬 처리의 좋은 점 중 하나는 Transformer 순방향 패스의 두 행렬과 잘 상호 작용한다는 것입니다. 순진하게는 두 행렬 각각 뒤에 AllReduce를 수행할 것입니다. 그러나 여기서는 먼저 In[B, DY] * Win[D, FY] -> Tmp[B, FY] 를 수행한 다음 Tmp[B, FY] * Wout[FY, D] -> Out[B, DY] 를 수행합니다. 이는 AllReduce를 수행하는 대신 시작 부분에서 In을 AllGather하고 끝 부분에서 Out을 ReduceScatter한다는 것을 의미합니다.
얼마나 비용이 들까요? 순방향 패스만 모델링해 보겠습니다. 역방향 패스는 여기서 각 연산의 전치일 뿐입니다. 1D 텐서 병렬 처리에서는 첫 번째 matmul 전에 활성화를 AllGather하고, 두 번째 후에 ReduceScatter하며, 한 번에 2바이트(bf16)를 보냅니다. 통신에 의해 병목이 발생하는 시점을 알아봅시다.
\[\begin{align} T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\ T_\text{comms} & = T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\ T_\text{comms} & = \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\\ \textnormal{T} & \approx \max \left(\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}, \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\right) \end{align}\]연산 비용이 통신 비용보다 크기를 원하므로 다음을 얻습니다:
\[\begin{align} \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} > \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}} \end{align}\] \[\begin{align} \frac{F}{Y \cdot C} > \frac{1}{W_\text{ici}} \end{align}\] \[\begin{align} F > Y \cdot \frac{C}{W_\text{ici}} \end{align}\]따라서 예를 들어, TPUv5p의 경우 bf16에서 $C / W_{ici} = 2550$이므로, $Y < F / 2550$까지만 텐서 병렬 처리를 수행할 수 있습니다. 여러 ICI 축이 있는 경우 $T_\text{comms}$가 $M_Y$배만큼 감소하므로 $Y < M_Y \cdot F / 2550$을 얻습니다.
Takeaway: 텐서 병렬 처리는 $Y > M_Y \cdot F / 2550$일 때 통신 병목이 발생합니다. 대부분의 모델에서 이는 8에서 16방향 텐서 병렬 처리 사이입니다.
이것은 연산 정밀도에 의존하지 않는다는 점에 유의하세요. 예를 들어 int8의 경우 TPUv5p에서 \(C_\text{int8} / W_{ici}\)는 \(2550\)대신\(5100\)이지만 통신량도 절반으로 줄어들므로 두 배수 요소가 상쇄됩니다.
몇 가지 예를 생각해 봅시다:
\(D = 8192,\) \(F \approx 30,000\)인 LLaMA 3-70B가 있는 TPUv5p에서는 8방향 텐서 병렬 처리를 편안하게 수행할 수 있지만 16방향 텐서 병렬 처리에서는 통신 병목이 발생합니다. 모델 8방향 모델 샤딩에 필요한 F는 20k입니다.
Gemma 7B의 경우 \(F \approx 50k\)이므로 19방향 텐서 병렬 처리에서 통신 병목이 발생합니다. 즉, 16방향을 수행해도 좋은 성능을 볼 수 있습니다.
Syntax: \(\text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y]\)
FSDP와 텐서 병렬 처리의 좋은 점은 결합될 수 있다는 것입니다. Win 과 Wout 을 두 축 모두에 샤딩함으로써 메모리와 연산을 모두 절약합니다. X를 따라 B를 샤딩하기 때문에 모델 병렬 AllGather의 크기를 줄이고, Y를 따라 F를 샤딩하기 때문에 FSDP의 통신 오버헤드를 줄입니다. 이는 두 가지를 결합하면 위에서 본 것보다 훨씬 더 낮은 유효 배치 크기에 도달할 수 있음을 의미합니다.
Forward pass: compute Loss[B]
Backward pass: compute dWout[FY, DX], dWin[DX, FY]
FSDP와 TP의 올바른 조합은 무엇일까요? 간단하지만 핵심적인 격언은 FSDP는 가중치를 이동시키고 텐서 병렬 처리는 활성화를 이동시킨다는 것입니다. 즉, 배치 크기가 줄어들수록(특히 데이터 병렬 처리를 더 많이 할수록) 샤드당 활성화가 작아지기 때문에 텐서 병렬 처리가 더 저렴해집니다.
따라서 두 가지를 결합하면 복제본당 최소 배치 크기를 훨씬 더 낮출 수 있습니다. 위와 같은 방식으로 최적의 FSDP 및 TP 양을 계산할 수 있습니다:
\(X\)를 FSDP에 할당된 칩 수, \(Y\)를 텐서 병렬 처리에 할당된 칩 수라고 합시다. \(N\)을 슬라이스의 총 칩 수라고 하고, \(N=XY\)입니다. \(M_X\)와 \(M_Y\)를 각각 FSDP와 TP를 수행하는 메시 축의 수라고 합시다(이들의 합은 대략 3이어야 함). FLOP당 통신이 가장 많은 순방향 패스만 모델링하겠습니다. 그러면 위 알고리즘의 통신을 더하면 다음과 같습니다.
\[T_\text{FSDP comms}(B, X, Y) = \frac{2\cdot 2\cdot D \cdot F}{Y \cdot W_\text{ici} \cdot M_X}\]\(T_\text{TP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y}\) \(T_\text{TP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y}\)
마찬가지로 총 FLOPs 시간은 다음과 같습니다.
\[T_\text{math} = \frac{2\cdot 2 \cdot B \cdot D \cdot F}{N \cdot C}.\]분석을 단순화하기 위해 두 가지 가정을 합니다. 첫째, $X$와 $Y$가 정수가 아닌 값을 가질 수 있도록 허용합니다($XY=N$을 만족하는 양수인 한). 둘째, $X$와 $Y$ 축의 통신을 서로 완전히 중첩할 수 있다고 가정합니다. 두 번째 가정하에 총 통신 시간은 다음과 같습니다.
\[T_\text{comms} = \max\left(T_\text{FSDP comms}, T_\text{TP comms}\right)\]어떤 조건에서 compute-bound가 될지 묻기 전에, 총 통신을 최소화하기 위한 $X$와 $Y$의 최적 값을 찾아봅시다. FLOPs는 $X$와 $Y$에 독립적이므로, 최적의 설정은 단순히 통신을 최소화하는 것입니다. 이를 위해 $T_\text{comms}$를 $X$와 $Y$가 아닌 $X$와 $N$(시스템의 칩 수이므로 고정됨)으로 다시 작성해 보겠습니다.
\[T_\text{comms} (X) = \frac{4D}{W_\text{ici}} \max\left(\frac{F \cdot X}{N \cdot M_X}, \frac{B}{X \cdot M_Y}\right)\]$T_\text{FSDP comms}$는 $X$에 대해 단조 증가하고 $T_\text{TP comms}$는 $X$에 대해 단조 감소하므로, 최대값은 $T_\text{FSDP comms} = T_\text{TP comms}$일 때 최소화되어야 합니다. 이는 다음과 같을 때 발생합니다.
\[\begin{align*} \frac{FX_{opt}}{M_X} = \frac{BN}{X_{opt} M_Y} \rightarrow \\ \frac{FX_{opt}}{M_X} = \frac{BN}{X_{opt} M_Y} \rightarrow \\ X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} \end{align*}\]이것은 매우 유용합니다! 주어진 $B$, $F$, $N$에 대해 최적의 FSDP 양을 알려줍니다. 규모 감각을 얻어 봅시다. 현실적인 값, 즉 $N = 64$ (4x4x4 칩 배열에 해당), $B=48,000$, $F=32768$을 대입하면 대략 $X\approx 13.9$를 얻습니다. 따라서 $X$를 16으로, $Y$를 4로 선택하여 계산된 최적값에 가깝게 할 것입니다.
Takeaway: 일반적으로 훈련 중에 최적의 FSDP 양은 \(X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N}\)입니다.
이제 모든 병렬 처리 전략에 대해 질문했던 것으로 돌아가 봅시다: 어떤 조건에서 compute-bound가 될까요? FLOPs와 통신을 중첩할 수 있으므로, 다음과 같을 때 compute-bound입니다.
\(\max\left(T_\text{FSDP comms}, T_\text{TP comms}\right) < T_\text{math}\) \(\max\left(T_\text{FSDP comms}, T_\text{TP comms}\right) < T_\text{math}\)
$\alpha \equiv C / W_\text{ici}$, 즉 ICI 연산 강도로 놓으면 단순화할 수 있습니다:
\(\max\left(\frac{F}{Y \cdot M_X}, \frac{B}{X \cdot M_Y}\right) < \frac{B \cdot F}{N \cdot \alpha}\) \(\max\left(\frac{F}{Y \cdot M_X}, \frac{B}{X \cdot M_Y}\right) < \frac{B \cdot F}{N \cdot \alpha}\)
좌변의 최대값이 같도록 $X_{opt}$를 계산했으므로 양쪽에 대입할 수 있습니다($Y_{opt} = N/X_{opt}$임에 유의).
\(\frac{F}{N \cdot W_\text{ici} \cdot M_X} \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} < \frac{B \cdot F}{N \cdot C}\) \(\frac{F}{N \cdot W_\text{ici} \cdot M_X} \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} < \frac{B \cdot F}{N \cdot C}\)
더 단순화하면 다음을 얻습니다.
\(\sqrt{\frac{B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha},\) \(\sqrt{\frac{B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha},\)
여기서 좌변은 통신 시간에 비례하고 우변은 계산 시간에 비례합니다. 계산 시간은 배치 크기에 선형적으로 비례하는 반면(병렬 처리에 관계없이), 통신 시간은 배치 크기의 제곱근으로 비례한다는 점에 유의하세요. 따라서 계산 시간 대 통신 시간의 비율도 배치 크기의 제곱으로 비례합니다:
\[\frac{T_\text{math}}{T_\text{comms}} = \frac{\sqrt{BF}\sqrt{M_X M_Y}}{\alpha \sqrt{N}}.\]이 비율이 1보다 커서 compute bound가 되도록 하려면 다음이 필요합니다.
\(\frac{B}{N} > \frac{\alpha^2}{M_X M_Y F}\) \(\frac{B}{N} > \frac{\alpha^2}{M_X M_Y F}\)
대략적인 숫자를 얻으려면 다시 $F=32,768$, $\alpha=2550$, $M_X M_Y=2$(3D 메시여야 하므로)를 대입합니다. 이는 대략 $B/N > 99$를 제공합니다. 이는 3D 메시를 가정할 때 $B/N$이 약 $850$을 초과해야 compute bound가 되는 순수 데이터 병렬 처리(또는 FSDP) 경우에 비해 대략 8배의 이득을 얻습니다.
Takeaway: 텐서 병렬 처리를 FSDP와 결합하면 $B/N$을 \(2550^2 / 2F\)로 낮출 수 있습니다. 이를 통해 칩당 100만큼 적은 배치를 처리할 수 있으며, 이는 FSDP만 사용할 때보다 대략 8배 더 작습니다.
아래에서 대표적인 4x4x4 칩 배열에 대해 텐서 병렬 처리(TP)만 사용하는 경우와 데이터 병렬 처리(FSDP)만 사용하는 경우와 비교하여 혼합 FSDP + TP의 FLOPs 대 통신 시간 비율을 플롯합니다. 순수 FSDP 병렬 처리가 매우 큰 배치 크기에서 우세하지만, 칩 수 대비 배치 크기가 대략 100에서 850 사이인 영역에서는 FSDP + TP 혼합 전략만이 1보다 큰 비율을 달성합니다.
다음은 TPU v5p 16x16x16의 또 다른 예로, 다양한 샤딩 방식에 대한 배치 크기 함수로서의 FLOPs 및 통신 시간을 보여줍니다.
검은색 곡선은 모델 FLOPs에 소비된 시간의 양을 의미하며, 이것이 모든 통신 비용보다 낮은 배치 크기는 엄격하게 comms bound입니다. 검은색 곡선이 녹색 곡선과 예측대로 약 4e5에서 교차하는 것을 알 수 있습니다.
다음은 다양한 배치 크기에 대한 총 계산 시간과 통신 시간을 보여주는 대화형 애니메이션입니다:
이것이 위와 대체로 일치함(FSDP=256, TP=16 부근에서 최소)을 알 수 있으며, 각각에 대한 축 수의 약간의 차이로 인한 약간의 오차 범위가 있습니다.
아마도 이전 섹션에서 파이프라이닝에 대해 이야기하는 것을 피했다는 것을 눈치채셨을 것입니다. 파이프라이닝은 TPU에서는 다소 덜 필수적인 GPU 병렬 처리를 위한 지배적인 전략입니다. 간단히 말해서, 파이프라인 훈련은 모델의 레이어를 여러 디바이스에 분할하고 순방향 및 역방향 패스 중에 파이프라인 단계 간에 활성화를 전달하는 것을 포함합니다. 알고리즘은 다음과 같습니다:
이 의사 코드는 Cloud TPU VM에서 실행되어야 합니다. 매우 효율적이거나 현실적이지는 않지만 데이터가 디바이스 간에 어떻게 전파되는지에 대한 감각을 제공합니다.
import jax
import jax.numpy as jnp
batch_size = 32
d_model = 128
d_ff = 4 * d_model
num_layers = len(jax.devices())
key = jax.random.PRNGKey(0)
# Pretend each layer is just a single matmul.
x = jax.random.normal(key, (batch_size, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))
def layer_fn(x, weight):
return x @ weight
# Assume we have num_layers == num_pipeline_stages
intermediates = [x]
for i in range(num_layers):
x = layer_fn(x, weights[i])
intermediates.append(x)
if i != num_layers - 1:
x = jax.device_put(x, jax.devices()[i+1])
def loss_fn(batch):
return jnp.mean(batch ** 2) # make up some fake loss function
loss, dx = jax.value_and_grad(loss_fn)(x)
for i in range(0, num_layers, -1):
_, f_vjp = jax.vjp(layer_fn, intermediates[i + 1], weights[i])
dx, dw = f_vjp(dx) # compute the jvp dx @ J(L)(x[i], W[i])
weights[i] = weights[i] - 0.01 * dw # update our weights
if i != 0:
dx = jax.device_put(dx, jax.devices()[i-1])
왜 이것이 좋은 생각일까요? 파이프라이닝은 여러 가지 이유로 훌륭합니다. 파이프라인 단계 간의 통신 비용이 낮으므로 낮은 대역폭 인터커넥트로도 매우 큰 모델을 훈련할 수 있습니다. 이는 TPU처럼 ICI로 밀집 연결되지 않은 GPU에서 종종 매우 유용합니다.
왜 이것이 어렵고/성가실까요? 위의 의사 코드에서 TPU 0이 거의 항상 유휴 상태라는 것을 눈치채셨을 것입니다! 파이프라인의 맨 처음과 마지막 단계에서만 작업을 수행합니다. 유휴 기간을 파이프라인 버블이라고 하며 처리하기가 매우 성가십니다. 일반적으로 마이크로배칭을 사용하여 먼저 이를 완화하려고 시도합니다. 이는 여러 개의 작은 배치를 파이프라인을 통해 보내 TPU 0이 전체 스텝 시간의 더 큰 부분 동안 활용되도록 유지합니다.
두 번째 접근 방식은 순방향 matmul $W_i @ x_i$, 역방향 $dx$ matmul $W_i @ \partial L / \partial x_{i+1}$, 그리고 $dW$ matmul $\partial L / \partial x_{i+1} @ x_i$를 신중하게 중첩하는 것입니다. 이들 각각은 약간의 FLOPs를 필요로 하므로 중첩하여 버블을 완전히 숨길 수 있습니다. 다음은 최근 DeepSeek v3 논문
TPU(상호 연결된 더 큰 pod를 가짐)에는 덜 중요하기 때문에 이에 대해 깊이 파고들지는 않겠지만, 주요 파이프라이닝 병목 현상을 이해하는 것은 좋은 연습입니다.
가장 큰 가능한 TPU 슬라이스는 8960개의 칩(및 2240개의 호스트)이 있는 TPU v5p SuperPod입니다. 이 크기를 넘어 확장하려면 데이터 센터 네트워킹(DCN) 경계를 넘어야 합니다. 각 TPU 호스트에는 이더넷을 통해 다른 TPU v5p pod에 호스트를 연결하는 하나 또는 여러 개의 NIC(네트워크 인터페이스 카드)가 장착되어 있습니다. TPU 섹션에서 언급했듯이, 각 호스트는 약 200Gbps(25GB/s)의 전이중(full-duplex) DCN 대역폭을 가지며, 이는 TPU당 약 6.25GB/s 전이중(송신) 대역폭입니다.
일반적으로 단일 pod를 넘어 확장할 때, ICI 도메인 내에서는 어떤 형태의 모델 병렬 처리 또는 FSDP를 수행하고, 여러 pod에 걸쳐서는 순수 데이터 병렬 처리를 수행합니다. $N$을 확장하려는 TPU 수, $M$을 ICI 연결 슬라이스당 TPU 수라고 합시다. DCN을 통해 AllReduce를 수행하려면 pod 세트에 대해 링 축소(ring-reduction)를 수행할 수 있습니다(역방향 패스에서):
\[T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot BDF}{N \cdot C}\] \[T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot DF}{M \cdot W_\text{dcn}}\]통신 대역폭은 $M$에 따라 확장됩니다. ICI와 달리 ICI 도메인을 키우고 더 많은 NIC를 확보함에 따라 총 대역폭이 증가하기 때문입니다. 단순화하면 다음과 같을 때 $T_\text{math} > T_\text{comms}$임을 알 수 있습니다.
\[\frac{B}{\text{slice}} > \frac{C}{W_\text{dcn}}\]TPU v5p의 경우 $\frac{C}{W_\text{dcn}}$은 약 4.46e14 / 6.25e9 = 71,360입니다. 이는 DCN을 통해 효율적으로 확장하려면 각 노드를 송신하는 데 필요한 ICI 도메인당 최소 배치 크기가 있음을 알려줍니다.
이것이 얼마나 문제인가요? 구체적인 예를 들어, BS=2M 토큰으로 TPU v5p에서 LLaMA-3 70B를 훈련하고 싶다고 가정해 봅시다. LLaMA-3 70B는 $F\approx 30,000$입니다. 위의 섹션에서 우리는 다음을 알고 있습니다:
요약하자면 BS=1M, 대략 X(FSDP)=1024 및 Y(TP)=8을 사용하여 훈련할 수 있는 좋은 레시피가 있지만, BS=2M에서는 DCN을 사용해야 합니다. 위에서 언급했듯이 $\text{71,360}$의 DCN 연산 강도를 가지고 있으므로, ICI 도메인당 배치 크기가 이보다 큰지 확인하기만 하면 됩니다. 2개의 pod를 사용하면 pod당 BS가 1M이고 GPU당 배치 크기가 111이므로 이는 우리에게 사소한 문제입니다(약간 아슬아슬할 수 있지만 이론적으로는 건전함).
Takeaway: 여러 TPU pod에 걸친 확장은 pod당 배치 크기가 최소 71k 토큰인 한 순수 데이터 병렬 처리를 사용하여 매우 간단합니다.
병렬 처리를 늘리거나 배치 크기를 줄이면 칩당 수행되는 연산량이 줄어들기 때문에 통신 병목이 더 많이 발생하는 경향이 있습니다.
합리적인 컨텍스트 길이(~32k)까지는 Transformer를 MLP 블록의 스택으로 모델링하고 레이어당 2/3개의 주요 matmul을 어떻게 샤딩하는지에 따라 여러 병렬 처리 방식을 정의할 수 있습니다.
훈련 중에는 4가지 주요 병렬 처리 방식을 고려하며, 각각 고유한 대역폭 및 연산 요구 사항이 있습니다(데이터 병렬 처리, FSDP, 텐서 병렬 처리).
| Strategy | Description |
|---|---|
| Data Parallelism | 활성화는 배치 샤딩되고, 다른 모든 것은 완전히 복제되며, 역방향 패스 중에 그라디언트를 all-reduce합니다. |
| FSDP | 활성화, 가중치 및 옵티마이저는 배치 샤딩되며, 가중치는 사용 직전에 gather되고, 그라디언트는 reduce-scatter됩니다. |
| Tensor Parallelism (aka Megatron, Model) | 활성화는 \(d_\text{model}\)에 따라 샤딩되고, 가중치는 \(d_{ff}\)에 따라 샤딩되며, 활성화는 Win 전에 gather되고, 결과는 Wout 후에 reduce-scatter됩니다. |
| Mixed FSDP + Tensor Parallelism | 위의 두 가지 모두이며, FSDP는 모델 샤딩된 가중치를 gather합니다. |
그리고 각 방법에 대한 “공식”은 다음과 같습니다:
\[\small \begin{array}{cc} \text{Strategy} & \text{Formula}\\ \hline \text{DP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D] \\ \text{FSDP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D] \\ \text{TP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \\ \text{TP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \\ \text{TP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \\ \text{TP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \\ \hline \end{array}\]순수 데이터 병렬 처리는 거의 유용하지 않습니다. 모델과 옵티마이저 상태가 파라미터 수의 10배인 바이트를 사용하기 때문입니다. 이는 수십억 개의 파라미터 이상을 메모리에 거의 맞출 수 없음을 의미합니다.
데이터 병렬 처리와 FSDP는 \(\text{batch size per shard} < C / W\)(네트워크의 연산 강도)일 때 comms bound가 됩니다. ICI의 경우 2,550이고 DCN의 경우 75,000입니다. 이는 더 많은 병렬 축으로 증가시킬 수 있습니다.
텐서 병렬 처리는 \(\lvert Y\rvert > F / 2550\)일 때 comms bound가 됩니다. 대부분의 모델에서 이는 약 8-16방향입니다. 이는 배치 크기와 무관합니다.
혼합 FSDP + 텐서 병렬 처리를 사용하면 배치 크기를 \(2550^2 / 2F \approx 100\)까지 낮출 수 있습니다. 이는 놀라울 정도로 낮습니다.
pod 간의 데이터 병렬 처리는 DCN-bound가 되기 전에 pod당 대략 75,000의 최소 배치 크기를 필요로 합니다.
기본적으로 배치 크기가 크거나 모델이 작다면 상황은 간단합니다. 데이터 병렬 처리를 하거나 DCN을 통한 FSDP + 데이터 병렬 처리를 할 수 있습니다. 중간 부분이 흥미로워지는 곳입니다.
이 섹션의 기본 모델로 LLaMA-2 13B를 사용합시다. 모델 세부 정보는 다음과 같습니다:
| hyperparam | value |
|---|---|
| L | 40 |
| D | 5,120 |
| F | 13824 |
| N | 40 |
| K | 40 |
| H | 128 |
| V | 32,000 |
LLaMA-2에는 별도의 임베딩 및 출력 행렬과 게이트형 MLP 블록이 있습니다.
Question 1: LLaMA-2 13B에는 얼마나 많은 파라미터가 있나요(어리석은 질문인 건 알지만 계산해 보세요)? Transformer Math에서와 같이 LLaMA-3에는 3개의 큰 FFW 행렬, 2개의 up-projection 및 1개의 down-projection이 있습니다. 이 섹션에서는 두 개의 “게이팅” einsum 행렬을 무시했지만, 이 섹션에서는 Win과 동일하게 동작합니다.
8.5e9 4.2e9 0.3e9 8.5e9 + 4.2e9 + 0.39e9 = 13.1e9, 예상대로입니다!Question 2: BS=16M 토큰으로 훈련하고 Adam을 사용한다고 가정해 봅시다. 잠시 병렬 처리를 무시하고, 모델의 파라미터, 옵티마이저 상태, 활성화에 사용되는 총 메모리는 얼마인가요? 파라미터는 bf16에, 옵티마이저 상태는 fp32에 저장하고 레이어당 세 번(세 개의 큰 matmul 이후) 활성화를 체크포인트한다고 가정합니다.
파라미터(bf16)와 두 개의 옵티마이저 상태(fp32, 1차 및 2차 모멘트 누적기)에 사용되는 총 메모리는 (2 + 4 + 4) * 13e9 ~ 130GB입니다. 처음 두 개의 matmul 후의 활성화는 모양이 $BF$이고 마지막 후에는 $BD$이므로(위의 Transformer 다이어그램에 따라), bf16의 총 메모리는 $2 \cdot L \cdot (BD + 2 * BF) = 2LB \cdot (D + 2F)$ 또는 2 * 40 * 16e6 * 5,120 * (1 + 2 * 2.7) ~ 4.2e13 = 42TB입니다(B=16e6이므로). 다른 모든 활성화는 다소 무시할 수 있습니다.
Question 3: TPUv5p 16x16x16 슬라이스에서 32k 시퀀스 길이와 총 배치 크기 3M 토큰으로 훈련하고 싶다고 가정해 봅시다. 위와 같이 bfloat16 가중치와 float32 옵티마이저를 사용하고 싶다고 가정합니다.
먼저 몇 가지 숫자를 적어 봅시다. 32k 시퀀스 길이와 3M 배치 크기로, 시퀀스 배치 크기는 96입니다. TPU v5p 16x16x16 슬라이스에는 총 393TB의 HBM이 있습니다.
순수 데이터 병렬 처리는 사용할 수 없습니다. 각 칩에 파라미터와 옵티마이저 상태를 복제하는데, 이는 이미 약 130GB(Q2에서)로 칩당 HBM(96GB)보다 많기 때문입니다.
메모리만 보는 것으로 시작해 봅시다. Q2에서 BS=16M을 3M으로 바꾸면 ~7.86e12 총 체크포인트 활성화를 얻고, 1.3e11 옵티마이저 상태를 더하면 거의 정확히 8e12 = 8TB가 됩니다. TPUv5p 슬라이스는 총 393TB의 HBM을 가지고 있으므로 HBM 제한 아래에 안전하게 있습니다. 다음으로 comms-bound인지 compute-bound인지 살펴봅시다. 4096개의 칩과 3개의 병렬 처리 축으로 850 * 4096 = 3.48M 토큰의 최소 배치 크기를 수행할 수 있습니다. 이는 3M 배치 크기보다 약간 높습니다. 따라서 실제로 comms-bound이며, 이는 슬픈 일입니다. 따라서 일반적인 대답은 아니요, FSDP만으로는 할 수 없습니다.
이제 우리의 주된 관심사가 comms-bound라는 것을 알았으므로 숫자를 대입해 봅시다. 우선 위에서 혼합 FSDP + 텐서 병렬 처리를 사용하는 칩당 배치 크기가 $2550^2 / 2F = 235$ 이상이어야 한다는 것을 알고 있습니다. 즉, 이론적으로 이것을 할 수 있습니다! 각각 얼마인지 알아봅시다.
$X_{opt} = \sqrt((F / B) * (M_X / M_Y) * N)$ 규칙이 있으므로, 여기서는 sqrt(3e6 * 2 * 4096 / 13824) = 1333을 가지며, 이는 대략 1024 방향 DP와 4 방향 TP를 수행한다는 것을 의미합니다. TPU당 메모리는 (2)와 같을 것이며, 스텝 시간은 6 * 3e6 * 13e9 / (4096 * 4.6e14 * 0.4) = 300ms가 될 것입니다.
위에서 Transformer 레이어 순방향 패스를 Out[B, D] = In[B, D] *D Win[D, F] *F Wout[F, D]로 단순화했습니다. 역방향 패스에 필요한 통신은 어떻게 유도할까요?
이는 단일 matmul Y = X * A에 대한 이전 섹션의 규칙에서 아주 자연스럽게 따릅니다:
\[\frac{dL}{dA} = \frac{dL}{dY}\frac{dY}{dA} = X^T \left(\frac{dL}{dY}\right)\] \[\frac{dL}{dX} = \frac{dL}{dY}\frac{dY}{dX} = \left(\frac{dL}{dY}\right) A^T\]이를 사용하여 다음 공식을 얻습니다(Tmp[B, F]는 In[B, D] * Win[D, F]를 나타냄):
이 공식들은 샤딩에 대한 언급이 없는 수학적 진술이라는 점에 유의하세요. 역방향 패스의 작업은 이 네 가지 수량을 계산하는 것입니다. 따라서 필요한 통신을 파악하려면 위의 네 가지 방정식(Tmp, dOut, Wout, Win)에서 matmul될 모든 수량의 샤딩(병렬화 방식에 의해 지정됨)을 가져와 샤딩된 matmul의 규칙을 사용하여 수행해야 할 통신을 파악하면 됩니다. dOut은 Out과 같은 방식으로 샤딩된다는 점에 유의하세요.