How to Parallelize a Transformer for Training

Part 5 of How To Scale Your Model (Part 4: Transformers | Part 6: Training LLaMA)

여기서는 LLM 훈련 중에 사용되는 4가지 주요 병렬 처리 방식인 데이터 병렬 처리(data parallelism), 완전 샤딩된 데이터 병렬 처리(FSDP), 텐서 병렬 처리(tensor parallelism), 그리고 파이프라인 병렬 처리(pipeline parallelism)에 대해 논의합니다. 각각에 대해 어느 시점에서 통신에 의해 병목 현상이 발생하는지 계산합니다.

번역 안내: 원저자(Jacob Austin)의 허락을 받아 원문을 번역 중입니다.
해당 글의 1인칭은 원문 저자를 지칭합니다.
원문: How to Scale Your Model
번역: 신종훈

What Do We Mean By Scaling?

“모델 스케일링”의 목표는 훈련이나 추론에 사용되는 칩의 수를 늘리면서 처리량을 비례적으로, 즉 선형적으로 증가시키는 것입니다(이를 strong scaling이라고 합니다). 단일 칩에서의 성능은 메모리 대역폭과 FLOPs 간의 트레이드오프에 달려있지만, 클러스터 수준에서의 성능은 칩 간 통신을 유용한 FLOPS와 중첩시켜 숨기는 것에 달려있습니다. 이는 간단하지 않은데, 칩 수를 늘리면 통신 부하가 늘어나는 동시에 이를 숨기는 데 사용할 수 있는 디바이스당 연산량은 줄어들기 때문입니다. 섹션 3에서 보았듯이, 샤딩된 행렬 곱셈은 종종 비싼 AllGather나 ReduceScatter를 필요로 하며, 이는 TPU가 유용한 작업을 하는 것을 막을 수 있습니다. 이 섹션의 목표는 이것들이 언제 너무 비싸지는지 알아내는 것입니다.

이 섹션에서는 네 가지 일반적인 병렬 처리 방식에 대해 논의합니다: (순수) 데이터 병렬 처리(data parallelism), 완전 샤딩된 데이터 병렬 처리(FSDP / ZeRO sharding), 텐서 병렬 처리(tensor parallelism)(모델 병렬 처리라고도 함), 그리고 (간략하게) 파이프라인 병렬 처리(pipeline parallelism)입니다. 각각에 대해 어떤 통신 비용이 발생하고 어느 시점에서 그 비용이 연산 비용에 병목이 되기 시작하는지 보여줄 것입니다.우리는 통신 제한에 초점을 맞출 것입니다. 메모리 용량 제약도 중요하지만, 사전 훈련 중 매우 많은 수의 칩과 rematerialization (activation checkpointing)을 사용할 때는 일반적으로 우리를 제한하지 않기 때문입니다. 또한 여기서는 MoE를 위한 expert parallelism에 대해서는 논의하지 않습니다. 이는 설계 공간을 상당히 확장시키기 때문에, 밀집(dense) Transformer의 기본 사례만 다룹니다. 이 섹션에서는 칩 간 통신 비용에만 집중하면 됩니다. 단일 칩 배치 크기가 충분히 크다면 HBM에서 MXU로의 데이터 전송은 이미 연산과 중첩되기 때문입니다.

이 섹션 전체에서 계산을 단순화하기 위해 다음 표기법을 사용할 것입니다.

Notation Meaning (model parameters)
D dmodel ( hidden dimension/residual stream dim)
F dff (feed-forward dimension)
B Batch dimension (배치의 토큰 수; 디바이스별이 아닌 전체)
T Sequence length
L Number of layers in the model
Notation Meaning (hardware characteristic)
C FLOPS/s per chip
W Network bandwidth (bidirectional, often subscripted as e.g. $W_{\text{ici}}$ or $W_{\text{dcn}}$
X Number of chips along mesh axis X
Y Number of chips along an alternate mesh axis, labeled Y
Z Number of chips along a third mesh axis, labeled Z

단순화를 위해, Transformer를 MLP 블록의 스택으로 근사하겠습니다. 섹션 4에서 보았듯이 어텐션은 대규모 모델에서 FLOPs의 비교적 적은 부분을 차지하기 때문입니다. 또한 게이팅 matmul을 무시하고 각 레이어에 대해 다음과 같은 간단한 구조만 남겨두겠습니다:

Figure: 단순화된 Transformer 레이어. 각 FFW 블록을 두 개의 행렬 Win: bf16[D, F] (up-projection) 과 Wout: bf16[F, D] (down-projection), 그리고 입력 In: bf16[B, D] 의 스택으로 취급합니다.
다음은 병렬 처리 없는 간단한 Transformer의 전체 알고리즘입니다.

Forward pass: compute Loss[B]

  1. Tmp[B, F] = In[B, D] *D Win[D, F]
  2. Out[B, D] = Tmp[B, F] *F Wout[F, D]
  3. Loss[B] = …

Backward pass: compute dWout[F, D], dWin[D, F]

  1. dOut[B, D] = …
  2. dWout[F, D] = Tmp[B, F] *B dOut[B, D]
  3. dTmp[B, F] = dOut[B, D] *D Wout[F, D]
  4. dWin[D, F] = In[B, D] *B dTmp[B, F]
  5. dIn[B, D] = dTmp[B, F] *F Win[D, F] (needed for previous layers)

통신이 추가된 알고리즘과 비교하기 위해 이를 제공합니다.

다음은 우리가 논의할 4가지 병렬 처리 방식입니다. 각 방식은 위 다이어그램의 In, Win, Wout, Out에 대한 샤딩에 의해 고유하게 정의되는 것으로 생각할 수 있습니다.

1. Data parallelism: 활성화는 배치에 따라 샤딩되고, 파라미터와 옵티마이저 상태는 각 디바이스에 복제됩니다. 통신은 역전파 중에만 발생합니다.

\[\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]\]

2. Fully-sharded data parallelism (FSDP or ZeRO-3): 활성화는 배치에 따라 샤딩되며(순수 데이터 병렬 처리처럼), 파라미터는 동일한 메시 축을 따라 샤딩되고 순방향 패스에서 사용되기 직전에 AllGather됩니다. 옵티마이저 상태도 배치에 따라 샤딩됩니다. 중복 메모리를 줄입니다.

\[\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]\]

3. Tensor parallelism (Megatron sharding 또는 model parallelism이라고도 함): 활성화는 D ($d_\text{model}$)에 따라 샤딩되고, 파라미터는 F ($d_{ff}$)에 따라 샤딩됩니다. 각 블록 전후에 활성화를 AllGather하고 ReduceScatter합니다. FSDP와 호환됩니다.

\[\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]\]

4. Pipeline parallelism: 가중치는 레이어 차원에 따라 샤딩되고, 활성화는 마이크로배치되어 레이어 차원에 따라 굴러갑니다(rolled). 파이프라인 단계 간 통신은 최소화됩니다(단일 홉으로 활성화만 이동). 표기법을 남용하자면:

\(\text{In}[L_Z, B, D][i] \cdot_D W_\text{in}[L_Z, D, F][i] \cdot_F W_\text{out}[L_Z, F, D][i] \rightarrow \text{Out}[L_Z, B, D][i]\) \(\text{In}[L_Z, B, D][i] \cdot_D W_\text{in}[L_Z, D, F][i] \cdot_F W_\text{out}[L_Z, F, D][i] \rightarrow \text{Out}[L_Z, B, D][i]\)

Data Parallelism

Syntax: \(\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]\)

모델이 아주 작은 배치 크기(>240 토큰, compute-bound가 되도록)라도 단일 칩에 맞는다면, 항상 단순한 데이터 병렬 처리를 사용해야 합니다. 순수 데이터 병렬 처리는 TPU 수가 배치 크기보다 작은 한 활성화를 원하는 수의 TPU에 걸쳐 분할합니다. 순방향 패스에는 통신이 포함되지 않지만, 모든 단계의 끝에서 각 TPU는 파라미터를 업데이트하기 전에 로컬 그라디언트를 동기화하기 위해 AllReduce를 수행합니다.

Figure: 순수 데이터 병렬 처리의 다이어그램(순방향 패스). 활성화(왼쪽)는 배치 차원을 따라 완전히 샤딩되고 가중치는 완전히 복제되므로, 각 TPU는 가중치의 동일한 사본을 갖습니다. 이는 가중치의 총 메모리가 N배 증가함을 의미하지만, 순방향 패스에서는 통신이 필요하지 않습니다.
다음은 순방향 및 역방향 패스에 대한 전체 알고리즘입니다. 간결함을 위해 dL/dOut을 dOut으로 표기합니다.

Pure Data Parallelism Algorithm:

Forward pass: compute Loss[BX]

  1. Tmp[BX, F] = In[BX, D] *D Win[D, F]
  2. Out[BX, D] = Tmp[BX, F] *F Wout[F, D]
  3. Loss[BX] = …

Backward pass: compute dWout[F, D], dWin[D, F]

  1. dOut[BX, D] = …
  2. dWout[F, D] {UX} = Tmp[BX, F] *B dOut[BX, D]
  3. dWout[F, D] = AllReduce(dWout[F, D] {UX}) (not on critical path, can be done async)
  4. dTmp[BX, F] = dOut[BX, D] *D Wout[F, D]
  5. dWin[D, F] {UX} = In[BX, D] *B dTmp[BX, F]
  6. dWin[D, F] = AllReduce(dWin[D, F] {UX}) (not on critical path, can be done async)
  7. dIn[BX, D] = dTmp[BX, F] *F Win[D, F] (needed for previous layers)

손실 함수의 세부 사항은 무시하고 $\text{Tmp} = W_\text{in} \cdot \text{In}$으로 축약합니다. 최종 손실은 평균 AllReduce(Loss[BX])이지만, 가중치 그라디언트를 평균화할 때 역방향 패스에서만 AllReduce를 계산하면 됩니다.

순방향 패스에는 통신이 없다는 점에 유의하세요 — 모두 역방향 패스에 있습니다! 역방향 패스는 또한 AllReduce가 “critical path”에 있지 않다는 훌륭한 속성을 가지고 있습니다. 즉, 각 AllReduce는 편리할 때 수행할 수 있으며 후속 작업을 수행하는 것을 차단하지 않습니다. 총 통신 비용이 총 연산 비용을 초과하면 전체 통신 비용이 여전히 병목이 될 수 있지만, 구현 관점에서는 훨씬 더 관대합니다. 모델/텐서 병렬 처리에는 이 속성이 없다는 것을 보게 될 것입니다.

왜 이렇게 할까요? 순수 데이터 병렬 처리는 배치 차원에 걸쳐 활성화를 분할하여 활성화 메모리 압박을 줄여줍니다. 배치 차원을 분할할 칩이 더 많다면 거의 임의로 배치 크기를 늘릴 수 있습니다. 특히 훈련 중에 활성화가 메모리 사용량을 지배하는 경우가 많으므로 이는 매우 유용합니다.

왜 이렇게 하지 않을까요? 순수 데이터 병렬 처리는 모델 파라미터나 옵티마이저 상태로 인한 메모리 압박을 줄이는 데 아무런 도움이 되지 않습니다. 즉, 파라미터 + 옵티마이저 상태가 단일 TPU에 맞지 않는 대규모의 흥미로운 모델에는 거의 유용하지 않습니다. 규모 감각을 주기 위해, 파라미터를 bf16으로, 옵티마이저 상태를 fp32로 AdamAdam은 파라미터, 1차 및 2차 모멘트 누적기를 저장합니다. 파라미터는 bfloat16이고 옵티마이저 상태는 float32이므로 파라미터당 `2 + 8 = 10` 바이트를 제공합니다.을 사용하여 훈련한다면, 맞을 수 있는 가장 큰 모델은 \(\text{TPU memory} / 10\) 파라미터를 가집니다. 따라서 96GB HBM과 순수 데이터 병렬 처리를 사용하는 TPUv5p 칩에서는 약 9B 파라미터입니다.

Takeaway: Adam과 순수 데이터 병렬 처리로 훈련할 수 있는 가장 큰 모델은 \(\text{num_params} = \text{HBM per device} / 10\)입니다. TPU v5p의 경우 대략 9B 파라미터입니다.이것은 gradient checkpoints를 포함하지 않으므로 실제로는 유용하지 않을 것입니다. 이것은 1 토큰 배치를 가정한 절대적인 하한입니다.

실제 모델 훈련에 유용하게 사용하려면 모델 파라미터나 옵티마이저를 적어도 부분적으로 샤딩해야 합니다.

언제 통신에 의해 병목이 발생할까요? 위에서 볼 수 있듯이 레이어당 두 개의 AllReduce가 있으며, 각각의 크기는 \(2DF\) (bf16 가중치용)입니다. 데이터 병렬 처리는 언제 통신 병목이 될까요?

위의 표에서와 같이, $C$ = 칩당 FLOPs, $W_{\text{ici}}$ = 양방향(bidirectional) 네트워크 대역폭, $X$ = 배치가 분할된 샤드 수이 파티셔닝이 ICI 메시 위에서 수행된다고 가정하므로 관련 네트워크 대역폭은 $W_\text{ici}$입니다.라고 합시다. 관련 matmul을 수행하는 데 필요한 시간 \(T_\text{math}\)와 필요한 통신 시간 \(T_\text{comms}\)를 계산해 봅시다. 이 병렬 처리 방식은 순방향 패스에서 통신이 필요하지 않으므로 역방향 패스에 대해서만 이 수량을 계산하면 됩니다.

Communication time: 이전 섹션에서 1D 메시에서 AllReduce를 수행하는 데 필요한 시간은 AllReduce되는 배열의 총 바이트 수와 ICI 대역폭 $W_\text{ici}$에만 의존한다는 것을 알고 있습니다. 구체적으로 AllReduce 시간은 $2 \cdot \text{total bytes} / W_\text{ici}$입니다. $W_\text{in}$과 $W_\text{out}$ 모두에 대해 AllReduce가 필요하므로 레이어당 2개의 AllReduce가 있습니다. 각 AllReduce는 가중치 행렬, 즉 $DF$ 파라미터 배열 또는 $2DF$ 바이트에 대한 것입니다. 이를 모두 종합하면 단일 레이어의 AllReduce 총 시간은 다음과 같습니다.

\[\begin{align} T_\text{comms} &= \frac{2 \cdot 2 \cdot 2 \cdot D \cdot F}{W_\text{ici}}. \\ \end{align}\]

Matmul time: 각 레이어는 순방향 패스에서 두 개의 matmul, 또는 역방향 패스에서 네 개의 matmul로 구성되며, 각각은 $2(B/X)DF$ FLOPs를 필요로 합니다. 따라서 역방향 패스의 단일 레이어에 대해 다음을 얻습니다.

\[\begin{align} T_\text{math} &= \frac{2 \cdot 2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ \end{align}\]

중첩하므로, 레이어당 총 시간은 이 두 수량의 최대값입니다:

\[\begin{aligned} T &\approx \max(\frac{8 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{8 \cdot D \cdot F}{W_\text{ici}}) \\ T &\approx 8 \cdot D \cdot F \cdot \max(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}) \end{aligned}\]

\(T_\text{math}/T_\text{comms} > 1\)일 때, 또는 다음과 같을 때 compute-bound가 됩니다.

\[\begin{align} \frac{B}{X} > \frac{C}{W_\text{ici}}. \end{align}\]

결론은 데이터 병렬 처리로 compute-bound를 유지하려면 디바이스당 배치 크기 \(B / X\)가 ICI 연산 강도(operational intensity) $C / W_\text{ici}$를 초과해야 한다는 것입니다. 이는 궁극적으로 계산 시간은 디바이스당 배치 크기에 비례하는 반면, 통신 시간은 이 수량과 무관하다는 사실(모델 가중치를 전송하고 있기 때문)의 결과입니다. $B > C/W_\text{ici}$ 조건이 단일 디바이스 compute-bound 규칙 $B > 240$과 유사하다는 점에 주목하세요. 그 경우에도 규칙은 계산 시간이 배치 크기에 비례하는 반면 데이터 전송 크기는 ($B \ll F, D$ 범위에서) 배치 크기와 무관하다는 사실에서 비롯되었습니다.

규모 감각을 얻기 위해 실제 숫자를 넣어 봅시다. TPUv5p의 경우 ICI를 통한 1D 데이터 병렬 처리에 대해 C=4.6e14이고 W=2 * 9e10이므로, 통신 병목 현상을 피하려면 칩당 배치 크기가 최소 2,550이어야 합니다. 여러 축에 대해 데이터 병렬 처리를 수행할 수 있으므로, TPUv5p pod의 세 축 모두를 순수 데이터 병렬 처리에 할애하면 대역폭 $W_\text{ici}$를 3배로 늘릴 수 있고 TPU당 BS=850 또는 pod(8960 칩)당 배치 7.6M 토큰까지 줄일 수 있습니다! 이는 순수 데이터 병렬 처리에 의해 병목 현상이 발생하기가 꽤 어렵다는 것을 말해줍니다!

Note [context parallelism]: 이 섹션 전체에서 $B$는 항상 토큰 단위의 총 배치 크기를 나타냅니다. 하지만 분명히 배치는 많은 다른 시퀀스로 구성되어 있는데, 이는 어떻게 작동할까요? MLP에 관한 한, 토큰은 토큰입니다! 같은 시퀀스에 속하든 다른 두 시퀀스에 속하든 상관없습니다. 따라서 우리는 배치 차원과 시퀀스 차원 모두에 대해 데이터 병렬 처리를 수행할 수 있습니다. 이를 컨텍스트 병렬 처리(context parallelism) 또는 시퀀스 병렬 처리(sequence parallelism)라고 부르지만, 단순히 또 다른 종류의 데이터 병렬 처리라고 생각할 수 있습니다. 어텐션은 시퀀스 간 계산을 수행하므로 MLP보다 까다롭지만, 이는 어텐션 중에 KV 또는 Q를 gather하고 FLOPs와 통신을 신중하게 중첩시켜(일반적으로 “ring attention”이라는 것을 사용) 처리할 수 있습니다. 이 섹션 전체에서 우리는 시퀀스 차원을 완전히 무시하고 일정량의 배치 또는 시퀀스 병렬 처리를 가정할 것입니다.

Note on multiple mesh axes: 여러 축이 사용 가능한 대역폭에 어떤 영향을 미치는지 빠르게 짚고 넘어가야 합니다. 주어진 병렬 처리 전략에 여러 메시 축을 사용할 때 더 많은 대역폭을 얻습니다.

Fully-Sharded Data Parallelism (FSDP)

Syntax: \(\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]\)

완전 샤딩된 데이터 병렬 처리(종종 FSDP 또는 ZeRO-sharding이라고 함)는 모델 옵티마이저 상태와 가중치를 데이터 병렬 샤드에 분할하고 필요에 따라 효율적으로 gather 및 scatter합니다. 순수 데이터 병렬 처리와 비교할 때, FSDP는 디바이스당 메모리 사용량을 대폭 줄이고 역방향 패스 FLOPs를 절약하며 오버헤드는 매우 적습니다.

Figure: FSDP는 Win의 축약 차원과 Wout의 출력 차원을 데이터 차원을 따라 샤딩합니다. 이는 메모리를 줄이지만(섹션 3에서), matmul을 수행하기 전에 W에 대한 가중치를 gather해야 합니다. 활성화(왼쪽)는 축약 차원을 따라 샤딩되지 않았다는 점에 유의하세요. 이것이 우리에게 gather를 강제하는 것입니다. 가중치 옵티마이저 상태도 마찬가지로 축약 차원을 따라 샤딩된다는 점에 유의하세요.

AllReduce가 AllGather와 ReduceScatter로 분해될 수 있다는 것을 기억하실 것입니다(섹션 3에서). 즉, 표준 데이터 병렬 처리를 위해 전체 그라디언트 AllReduce를 수행하는 대신, 칩에 가중치와 옵티마이저 상태를 샤딩하고, 순방향 패스 동안 각 레이어에서 AllGather하고, 역방향 패스 동안 가중치에 대해 추가 비용 없이 ReduceScatter할 수 있습니다.

다음은 FSDP의 전체 알고리즘입니다.

Fully-Sharded Data Parallelism (FSDP):

Forward pass: compute Loss[BX]

  1. Win[D, F] = AllGather(Win[DX, F]) (not on critical path, can do it during previous layer)
  2. Tmp[BX, F] = In[BX, D] *D Win[D, F] (can throw away Win[D, F] now)
  3. Wout[F, D] = AllGather(Wout[F, DX]) (not on critical path, can do it during previous layer)
  4. Out[BX, D] = Tmp[BX, F] *F Wout[F, D]
  5. Loss[BX] = …

Backward pass: compute dWout[F, DX], dWin[DX, F]

  1. dOut[BX, D] = …
  2. dWout[F, D] {UX} = Tmp[BX, F] *B dOut[BX, D]
  3. dWout[F, DX] = ReduceScatter(dWout[F, D] {UX}) (not on critical path, can be done async)
  4. Wout[F, D] = AllGather(Wout[F, DX]) (can be done ahead of time)
  5. dTmp[BX, F] = dOut[BX, D] *D Wout[F, D] (can throw away Wout[F, D] here)
  6. dWin[D,F] {UX} = dTmp[BX, F] *B In[BX, D]
  7. dWin[DX, F] = ReduceScatter(dWin[D, F] {UX}) (not on critical path, can be done async)
  8. Win[D, F] = AllGather(Win[DX, F]) (can be done ahead of time)
  9. dIn[BX, D] = dTmp[BX, F] *F Win[D, F] (needed for previous layers) (can throw away Win[D, F] here)

이것은 불필요한 연산을 수행하거나 불필요한 상태를 저장하지 않기 때문에 “ZeRo Overhead sharding”에서 유래하여 “ZeRO Sharding”이라고도 합니다. ZeRO-{1,2,3}은 각각 옵티마이저 상태, 그라디언트, 가중치를 이런 방식으로 샤딩하는 것을 지칭하는 데 사용됩니다. 모두 동일한 통신 비용을 가지므로엄밀히 말하면 FSDP는 순수 DP에는 없는 순방향 패스 통신을 추가하지만, 이는 역방향 패스와 동일한 비율이므로 통신 루프라인에 영향을 미치지 않아야 합니다. 여기서 핵심은 ZeRO-3가 역방향 패스 AllReduce를 AllGather와 ReduceScatter로 바꾼다는 것이며, 이는 동일한 총 통신량을 가집니다., 우리는 기본적으로 항상 파라미터, 그라디언트, 옵티마이저 상태를 디바이스 세트에 샤딩하는 ZeRO-3 샤딩을 수행할 수 있습니다.

왜 이렇게 할까요? 표준 데이터 병렬 처리는 많은 중복 작업을 포함합니다. 각 TPU는 전체 그라디언트를 AllReduce한 다음 전체 옵티마이저 상태를 업데이트하고(모든 TPU에서 동일한 작업), 파라미터를 업데이트합니다(다시 완전히 중복됨). ZeRO 샤딩(그라디언트/옵티마이저 상태 샤딩)의 경우, AllReduce 대신 그라디언트를 ReduceScatter하고, 옵티마이저 상태의 샤드만 업데이트하고, 파라미터의 샤드를 업데이트한 다음, 순방향 패스에 필요할 때 파라미터를 AllGather할 수 있습니다.

언제 통신에 의해 병목이 발생할까요? 역방향 패스의 각 AllReduce가 AllGather + ReduceScatter가 되었기 때문에 우리의 상대적인 FLOPs 및 통신 비용은 순수 데이터 병렬 처리와 정확히 동일합니다. AllReduce는 각각 절반의 비용을 가진 AllGather와 ReduceScatter로 구현된다는 것을 상기하세요. 여기서는 역방향 패스와 동일한 FLOPs 대 통신 비율을 가지므로 순방향 패스를 모델링합니다:

\[\begin{aligned} T_\text{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ T_\text{comms} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\ T_\text{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ T_\text{comms} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\ T &\approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{4 \cdot D \cdot F}{W_\text{ici}}\right) \\ T &\approx 4 \cdot D \cdot F \cdot \max\left(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}\right) \end{aligned}\]

따라서 순수 데이터 병렬 처리와 마찬가지로, \(B / X > C / W_\text{ici}\)일 때, 즉 디바이스당 배치 크기 $B/X$가 “ICI 연산 강도” $C/W_\text{ici}$ (v5p의 경우 4.59e14 / 1.8e11 = 2550)를 초과할 때 compute bound입니다. 이는 우리에게 아주 좋은 일입니다. 왜냐하면 순수 데이터 병렬 처리에 대해 compute-bound일 만큼 디바이스당 배치 크기가 크다면, compute-bound 영역을 벗어나는 것에 대해 걱정하지 않고 단순히 FSDP로 업그레이드하여 엄청난 양의 파라미터 및 옵티마이저 상태 메모리를 절약할 수 있다는 것을 의미하기 때문입니다! 순방향 패스에 통신을 추가해야 했지만, 이 비용은 순방향 패스 FLOPs와 겹치기 때문에 중요하지 않습니다.

Takeaway: FSDP와 순수 데이터 병렬 처리는 모두 디바이스당 배치 크기가 $2550 / M_X$ 미만일 때 TPUv5에서 대역폭 병목이 발생합니다. 여기서 $M_X$는 메시 축의 수입니다.

예를 들어, DeepSeek-V2(훈련 배치 크기에 대한 정보를 공개한 유일한 최신 강력한 모델 중 하나)는 ~40M 토큰의 배치 크기를 사용했습니다. 이를 통해 대역폭 한계에 도달하기 전에 대략 47,000개의 칩, 즉 약 5개의 TPUv5 pod로 확장할 수 있습니다.

LLaMA-3 70B의 경우, 약 6.3e24 (15e12 * 70e9 * 6) FLOPs로 훈련되었으며, 16M 토큰 배치를 대략 16e6 / (2550 / 3) = 18,823개의 칩(대략 8960개 칩의 2개 pod)으로 나눌 수 있으며, 각 칩은 50% 피크 FLOPs 활용률(종종 MFU라고 함)에서 4.59e14 FLOPs를 실행하여, 약 17일 만에 훈련할 수 있습니다. 나쁘지 않습니다! 하지만 더 잘할 수 있는 방법을 알아봅시다.

Note on critical batch size: 다소 직관적이지 않게도, 총 배치 크기가 줄어들수록(고정된 칩 수에서) 통신 병목 현상이 더 많이 발생합니다. 데이터 병렬 처리와 FSDP를 사용하면 배치 크기를 계속 늘릴 수만 있다면 임의로 많은 칩으로 확장할 수 있습니다! 그러나 실제로 배치 크기가 증가함에 따라 그라디언트가 거의 노이즈가 없어지기 때문에 훈련에서 수확 체감을 보는 경향이 있습니다. 또한 때때로 훈련 불안정을 봅니다. 따라서 “무제한 컴퓨팅 체제”에서 최적의 샤딩 방식을 찾는 게임은 종종 스케일링 법칙에 의해 결정된 고정된 배치 크기와 알려진(큰) 칩 수에서 시작하여, 그 작은 배치 크기를 그렇게 많은 칩에 맞출 수 있는 파티셔닝을 찾는 것을 목표로 합니다.

Tensor Parallelism

Syntax: \(\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]\)(결국 FSDP와 결합하기 위해\(Y\)를 사용합니다)

완전 샤딩된 데이터 병렬 AllReduce에서는 가중치를 칩 간에 이동시킵니다. 모델의 피드포워드 차원을 샤딩하고 레이어 동안 활성화를 이동시킬 수도 있습니다. 이를 “1D 모델 병렬 처리” 또는 Megatron 샤딩이라고 합니다. 이를 통해 pod당 더 작은 효율적인 배치 크기를 확보할 수 있습니다. 아래 그림은 이 방식으로 샤딩된 단일 행렬의 예를 보여줍니다:

Figure: 기본 텐서 병렬 처리의 예. 활성화를 Y에 대해서만 샤딩하고 있으므로(X에 대해 샤딩하는 FSDP와 달리), 활성화를 X에 대해 복제합니다. 표준 구문을 사용하여 이는 A[B, DY] * B[D, FY] -> C[B, FY]입니다. 축약 차원 중 하나에 대해서만 샤딩하고 있으므로, 일반적으로 matmul 전에 활성화 A를 AllGather합니다.

언급했듯이, In[B, DY] *D Win[D, FY] *F Wout[FY, D] -> Out[B, DY] 는 첫 번째 matmul 전에 활성화를 gather해야 함을 의미합니다. 활성화가 가중치보다 작을 때 이것은 ZeRO 샤딩보다 저렴합니다. 이는 일반적으로 일정량의 ZeRO 샤딩이 추가된 경우에만 해당됩니다(gather의 크기를 줄여줌). 이것이 우리가 ZeRO 샤딩과 텐서 병렬 처리를 혼합하는 경향이 있는 이유 중 하나입니다.

다음은 텐서 병렬 처리를 위한 알고리즘입니다!

Tensor Parallelism:

Forward pass: compute Loss[B]

  1. In[B, D] = AllGather(In[B, DY]) (on critical path)
  2. Tmp[B, FY] = In[B, D] *D Win[D, FY] (not sharded along contracting, so no comms)
  3. Out[B, D] {UY} = Tmp[B, FY] *F Wout[FY, D]
  4. Out[B, DY] = ReduceScatter(Out[B, D] {UY}) (on critical path)
  5. Loss[B] = …

Backward pass: compute dWout[FY, D], dWin[D, FY]

  1. dOut[B, DY] = …
  2. dOut[B, D] = AllGather(dOut[B, DY]) (on critical path)
  3. dWout[FY, D] = Tmp[B, FY] *B dOut[B, D]
  4. dTmp[B, FY] = dOut[B, D] *D Wout[FY, D] (can throw away dOut[B, D] here)
  5. In[B, D] = AllGather(In[B, DY]) (this can be skipped by sharing with (1) from the forward pass)
  6. dWin[D, FY] = dTmp[B, FY] *B In[B, D]
  7. dIn[B, D] {U.Y} = dTmp[B, FY] *F Win[D, FY] (needed for previous layers)
  8. dIn[B, DY] = ReduceScatter(dIn[B, D] {U.Y}) (on critical path)

텐서 병렬 처리의 좋은 점 중 하나는 Transformer 순방향 패스의 두 행렬과 잘 상호 작용한다는 것입니다. 순진하게는 두 행렬 각각 뒤에 AllReduce를 수행할 것입니다. 그러나 여기서는 먼저 In[B, DY] * Win[D, FY] -> Tmp[B, FY] 를 수행한 다음 Tmp[B, FY] * Wout[FY, D] -> Out[B, DY] 를 수행합니다. 이는 AllReduce를 수행하는 대신 시작 부분에서 In을 AllGather하고 끝 부분에서 Out을 ReduceScatter한다는 것을 의미합니다.

얼마나 비용이 들까요? 순방향 패스만 모델링해 보겠습니다. 역방향 패스는 여기서 각 연산의 전치일 뿐입니다. 1D 텐서 병렬 처리에서는 첫 번째 matmul 전에 활성화를 AllGather하고, 두 번째 후에 ReduceScatter하며, 한 번에 2바이트(bf16)를 보냅니다. 통신에 의해 병목이 발생하는 시점을 알아봅시다.

\[\begin{align} T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\ T_\text{comms} & = T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\ T_\text{comms} & = \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\\ \textnormal{T} & \approx \max \left(\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}, \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\right) \end{align}\]

연산 비용이 통신 비용보다 크기를 원하므로 다음을 얻습니다:

\[\begin{align} \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} > \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}} \end{align}\] \[\begin{align} \frac{F}{Y \cdot C} > \frac{1}{W_\text{ici}} \end{align}\] \[\begin{align} F > Y \cdot \frac{C}{W_\text{ici}} \end{align}\]

따라서 예를 들어, TPUv5p의 경우 bf16에서 $C / W_{ici} = 2550$이므로, $Y < F / 2550$까지만 텐서 병렬 처리를 수행할 수 있습니다. 여러 ICI 축이 있는 경우 $T_\text{comms}$가 $M_Y$배만큼 감소하므로 $Y < M_Y \cdot F / 2550$을 얻습니다.

Takeaway: 텐서 병렬 처리는 $Y > M_Y \cdot F / 2550$일 때 통신 병목이 발생합니다. 대부분의 모델에서 이는 8에서 16방향 텐서 병렬 처리 사이입니다.

이것은 연산 정밀도에 의존하지 않는다는 점에 유의하세요. 예를 들어 int8의 경우 TPUv5p에서 \(C_\text{int8} / W_{ici}\)는 \(2550\)대신\(5100\)이지만 통신량도 절반으로 줄어들므로 두 배수 요소가 상쇄됩니다.

몇 가지 예를 생각해 봅시다:

Combining FSDP and Tensor Parallelism

Combining FSDP and Tensor Parallelism

Syntax: \(\text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y]\)

FSDP와 텐서 병렬 처리의 좋은 점은 결합될 수 있다는 것입니다. WinWout 을 두 축 모두에 샤딩함으로써 메모리와 연산을 모두 절약합니다. X를 따라 B를 샤딩하기 때문에 모델 병렬 AllGather의 크기를 줄이고, Y를 따라 F를 샤딩하기 때문에 FSDP의 통신 오버헤드를 줄입니다. 이는 두 가지를 결합하면 위에서 본 것보다 훨씬 더 낮은 유효 배치 크기에 도달할 수 있음을 의미합니다.

Figure: FSDP와 텐서 병렬 처리를 결합한 다이어그램. 다른 경우와 달리 모델 파라미터의 중복이 없습니다.
다음은 혼합 FSDP + 텐서 병렬 처리를 위한 전체 알고리즘입니다. 통신이 많지만, 활성화를 배치 샤딩하고 가중치를 텐서 샤딩했기 때문에 모든 AllGather와 ReduceScatter가 더 작습니다!

Forward pass: compute Loss[B]

  1. In[BX, D] = AllGatherY(In[BX, DY]) (on critical path)
  2. Win[D, FY] = AllGatherX(Win[DX, FY]) (can be done ahead of time)
  3. Tmp[BX, FY] = In[BX, D] *D Win[D, FY]
  4. Wout[FY, D] = AllGatherX(Wout[FY, DX]) (can be done ahead of time)
  5. Out[BX, D] {U.Y} = Tmp[BX, FY] *F Wout[FY, D]
  6. Out[BX, DY] = ReduceScatterY(Out[BX, D] {U.Y}) (on critical path)
  7. Loss[BX] = …

Backward pass: compute dWout[FY, DX], dWin[DX, FY]

  1. dOut[BX, DY] = …
  2. dOut[BX, D] = AllGatherY(dOut[BX, DY]) (on critical path)
  3. dWout[FY, D] {U.X} = Tmp[BX, FY] *B dOut[BX, D]
  4. dWout[FY, DX] = ReduceScatterX(dWout[FY, D] {U.X})
  5. Wout[FY, D] = AllGatherX(Wout[FY, DX]) (can be done ahead of time)
  6. dTmp[BX, FY] = dOut[BX, D] *D Wout[FY, D] (can throw away dOut[B, D] here)
  7. In[BX, D] = AllGatherY(In[BX, DY]) (not on critical path + this can be shared with (2) from the previous layer)
  8. dWin[D, FY] {U.X} = dTmp[BX, FY] *B In[BX, D]
  9. dWin[DX, FY] = ReduceScatterX(dWin[D, FY] {U.X})
  10. Win[D, FY] = AllGatherX(Win[DX, FY]) (can be done ahead of time)
  11. dIn[BX, D] {U.Y} = dTmp[BX, FY] *F Win[D, FY] (needed for previous layers)
  12. dIn[BX, DY] = ReduceScatterY(dIn[BX, D] {U.Y}) (on critical path)

FSDP와 TP의 올바른 조합은 무엇일까요? 간단하지만 핵심적인 격언은 FSDP는 가중치를 이동시키고 텐서 병렬 처리는 활성화를 이동시킨다는 것입니다. 즉, 배치 크기가 줄어들수록(특히 데이터 병렬 처리를 더 많이 할수록) 샤드당 활성화가 작아지기 때문에 텐서 병렬 처리가 더 저렴해집니다.

따라서 두 가지를 결합하면 복제본당 최소 배치 크기를 훨씬 더 낮출 수 있습니다. 위와 같은 방식으로 최적의 FSDP 및 TP 양을 계산할 수 있습니다:

\(X\)를 FSDP에 할당된 칩 수, \(Y\)를 텐서 병렬 처리에 할당된 칩 수라고 합시다. \(N\)을 슬라이스의 총 칩 수라고 하고, \(N=XY\)입니다. \(M_X\)와 \(M_Y\)를 각각 FSDP와 TP를 수행하는 메시 축의 수라고 합시다(이들의 합은 대략 3이어야 함). FLOP당 통신이 가장 많은 순방향 패스만 모델링하겠습니다. 그러면 위 알고리즘의 통신을 더하면 다음과 같습니다.

\[T_\text{FSDP comms}(B, X, Y) = \frac{2\cdot 2\cdot D \cdot F}{Y \cdot W_\text{ici} \cdot M_X}\]

\(T_\text{TP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y}\) \(T_\text{TP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y}\)

마찬가지로 총 FLOPs 시간은 다음과 같습니다.

\[T_\text{math} = \frac{2\cdot 2 \cdot B \cdot D \cdot F}{N \cdot C}.\]

분석을 단순화하기 위해 두 가지 가정을 합니다. 첫째, $X$와 $Y$가 정수가 아닌 값을 가질 수 있도록 허용합니다($XY=N$을 만족하는 양수인 한). 둘째, $X$와 $Y$ 축의 통신을 서로 완전히 중첩할 수 있다고 가정합니다. 두 번째 가정하에 총 통신 시간은 다음과 같습니다.

\[T_\text{comms} = \max\left(T_\text{FSDP comms}, T_\text{TP comms}\right)\]

어떤 조건에서 compute-bound가 될지 묻기 전에, 총 통신을 최소화하기 위한 $X$와 $Y$의 최적 값을 찾아봅시다. FLOPs는 $X$와 $Y$에 독립적이므로, 최적의 설정은 단순히 통신을 최소화하는 것입니다. 이를 위해 $T_\text{comms}$를 $X$와 $Y$가 아닌 $X$와 $N$(시스템의 칩 수이므로 고정됨)으로 다시 작성해 보겠습니다.

\[T_\text{comms} (X) = \frac{4D}{W_\text{ici}} \max\left(\frac{F \cdot X}{N \cdot M_X}, \frac{B}{X \cdot M_Y}\right)\]

$T_\text{FSDP comms}$는 $X$에 대해 단조 증가하고 $T_\text{TP comms}$는 $X$에 대해 단조 감소하므로, 최대값은 $T_\text{FSDP comms} = T_\text{TP comms}$일 때 최소화되어야 합니다. 이는 다음과 같을 때 발생합니다.

\[\begin{align*} \frac{FX_{opt}}{M_X} = \frac{BN}{X_{opt} M_Y} \rightarrow \\ \frac{FX_{opt}}{M_X} = \frac{BN}{X_{opt} M_Y} \rightarrow \\ X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} \end{align*}\]

이것은 매우 유용합니다! 주어진 $B$, $F$, $N$에 대해 최적의 FSDP 양을 알려줍니다. 규모 감각을 얻어 봅시다. 현실적인 값, 즉 $N = 64$ (4x4x4 칩 배열에 해당), $B=48,000$, $F=32768$을 대입하면 대략 $X\approx 13.9$를 얻습니다. 따라서 $X$를 16으로, $Y$를 4로 선택하여 계산된 최적값에 가깝게 할 것입니다.

Takeaway: 일반적으로 훈련 중에 최적의 FSDP 양은 \(X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N}\)입니다.

이제 모든 병렬 처리 전략에 대해 질문했던 것으로 돌아가 봅시다: 어떤 조건에서 compute-bound가 될까요? FLOPs와 통신을 중첩할 수 있으므로, 다음과 같을 때 compute-bound입니다.

\(\max\left(T_\text{FSDP comms}, T_\text{TP comms}\right) < T_\text{math}\) \(\max\left(T_\text{FSDP comms}, T_\text{TP comms}\right) < T_\text{math}\)

$\alpha \equiv C / W_\text{ici}$, 즉 ICI 연산 강도로 놓으면 단순화할 수 있습니다:

\(\max\left(\frac{F}{Y \cdot M_X}, \frac{B}{X \cdot M_Y}\right) < \frac{B \cdot F}{N \cdot \alpha}\) \(\max\left(\frac{F}{Y \cdot M_X}, \frac{B}{X \cdot M_Y}\right) < \frac{B \cdot F}{N \cdot \alpha}\)

좌변의 최대값이 같도록 $X_{opt}$를 계산했으므로 양쪽에 대입할 수 있습니다($Y_{opt} = N/X_{opt}$임에 유의).

\(\frac{F}{N \cdot W_\text{ici} \cdot M_X} \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} < \frac{B \cdot F}{N \cdot C}\) \(\frac{F}{N \cdot W_\text{ici} \cdot M_X} \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} < \frac{B \cdot F}{N \cdot C}\)

더 단순화하면 다음을 얻습니다.

\(\sqrt{\frac{B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha},\) \(\sqrt{\frac{B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha},\)

여기서 좌변은 통신 시간에 비례하고 우변은 계산 시간에 비례합니다. 계산 시간은 배치 크기에 선형적으로 비례하는 반면(병렬 처리에 관계없이), 통신 시간은 배치 크기의 제곱근으로 비례한다는 점에 유의하세요. 따라서 계산 시간 대 통신 시간의 비율도 배치 크기의 제곱으로 비례합니다:

\[\frac{T_\text{math}}{T_\text{comms}} = \frac{\sqrt{BF}\sqrt{M_X M_Y}}{\alpha \sqrt{N}}.\]

이 비율이 1보다 커서 compute bound가 되도록 하려면 다음이 필요합니다.

\(\frac{B}{N} > \frac{\alpha^2}{M_X M_Y F}\) \(\frac{B}{N} > \frac{\alpha^2}{M_X M_Y F}\)

대략적인 숫자를 얻으려면 다시 $F=32,768$, $\alpha=2550$, $M_X M_Y=2$(3D 메시여야 하므로)를 대입합니다. 이는 대략 $B/N > 99$를 제공합니다. 이는 3D 메시를 가정할 때 $B/N$이 약 $850$을 초과해야 compute bound가 되는 순수 데이터 병렬 처리(또는 FSDP) 경우에 비해 대략 8배의 이득을 얻습니다.

Takeaway: 텐서 병렬 처리를 FSDP와 결합하면 $B/N$을 \(2550^2 / 2F\)로 낮출 수 있습니다. 이를 통해 칩당 100만큼 적은 배치를 처리할 수 있으며, 이는 FSDP만 사용할 때보다 대략 8배 더 작습니다.

아래에서 대표적인 4x4x4 칩 배열에 대해 텐서 병렬 처리(TP)만 사용하는 경우와 데이터 병렬 처리(FSDP)만 사용하는 경우와 비교하여 혼합 FSDP + TP의 FLOPs 대 통신 시간 비율을 플롯합니다. 순수 FSDP 병렬 처리가 매우 큰 배치 크기에서 우세하지만, 칩 수 대비 배치 크기가 대략 100에서 850 사이인 영역에서는 FSDP + TP 혼합 전략만이 1보다 큰 비율을 달성합니다.

Figure: F=30k인 TPUv5p 4x4x4 슬라이스에서 최적의 혼합 FSDP/TP에 대한 FLOPs 대 통신 시간 비율. 예상대로 텐서 병렬 처리는 배치 크기에 따라 고정된 비율을 갖습니다. 이상적인 혼합 FSDP + TP는 $\sqrt{B}$로 스케일링되고 FSDP는 $B$로 스케일링됩니다. 그러나 중간 배치 크기 영역에서는 FSDP + TP만이 1보다 큰 비율을 달성합니다.

다음은 TPU v5p 16x16x16의 또 다른 예로, 다양한 샤딩 방식에 대한 배치 크기 함수로서의 FLOPs 및 통신 시간을 보여줍니다.

Figure: 다양한 병렬 처리 방식에 따른 통신 소요 시간. 검은 점선은 행렬 곱셈 FLOPs에 소요되는 시간이므로, 이 선 위의 모든 곡선은 comms-bound입니다. 모든 전략이 배치 크기 6e5 미만에서 comms-bound가 되는 것을 알 수 있는데, 이는 예상치인 4096 * 2550^2 / (2 * 8192 * 4) = 4e5와 일치합니다.

검은색 곡선은 모델 FLOPs에 소비된 시간의 양을 의미하며, 이것이 모든 통신 비용보다 낮은 배치 크기는 엄격하게 comms bound입니다. 검은색 곡선이 녹색 곡선과 예측대로 약 4e5에서 교차하는 것을 알 수 있습니다.

다음은 다양한 배치 크기에 대한 총 계산 시간과 통신 시간을 보여주는 대화형 애니메이션입니다:

이것이 위와 대체로 일치함(FSDP=256, TP=16 부근에서 최소)을 알 수 있으며, 각각에 대한 축 수의 약간의 차이로 인한 약간의 오차 범위가 있습니다.

Pipelining

아마도 이전 섹션에서 파이프라이닝에 대해 이야기하는 것을 피했다는 것을 눈치채셨을 것입니다. 파이프라이닝은 TPU에서는 다소 덜 필수적인 GPU 병렬 처리를 위한 지배적인 전략입니다. 간단히 말해서, 파이프라인 훈련은 모델의 레이어를 여러 디바이스에 분할하고 순방향 및 역방향 패스 중에 파이프라인 단계 간에 활성화를 전달하는 것을 포함합니다. 알고리즘은 다음과 같습니다:

  1. TPU 0에서 레이어 차원에 걸쳐 샤딩된 가중치로 데이터를 초기화합니다(FSDP 및 텐서 병렬 처리가 있는 파이프라이닝의 경우 $W_\text{in}[L_Z, D_X, F_Y]$).
  2. TPU 0에서 첫 번째 레이어를 수행한 다음 결과 활성화를 TPU 1로 복사하고 마지막 TPU에 도달할 때까지 반복합니다.
  3. 손실 함수와 그 미분 $\partial L / \partial x_L$을 계산합니다.
  4. 마지막 파이프라인 단계에 대해 미분 $\partial L / \partial W_L$ 및 $\partial L / \partial x_{L-1}$을 계산한 다음 $\partial L / \partial x_{L-1}$을 이전 파이프라인 단계로 복사하고 TPU 0에 도달할 때까지 반복합니다.
다음은 (작동하는) Python 의사 코드입니다.

이 의사 코드는 Cloud TPU VM에서 실행되어야 합니다. 매우 효율적이거나 현실적이지는 않지만 데이터가 디바이스 간에 어떻게 전파되는지에 대한 감각을 제공합니다.

import jax
import jax.numpy as jnp

batch_size = 32
d_model = 128
d_ff = 4 * d_model

num_layers = len(jax.devices())

key = jax.random.PRNGKey(0)

# Pretend each layer is just a single matmul.
x = jax.random.normal(key, (batch_size, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))

def layer_fn(x, weight):
  return x @ weight

# Assume we have num_layers == num_pipeline_stages
intermediates = [x]
for i in range(num_layers):
  x = layer_fn(x, weights[i])
  intermediates.append(x)

  if i != num_layers - 1:
    x = jax.device_put(x, jax.devices()[i+1])

def loss_fn(batch):
  return jnp.mean(batch ** 2)  # make up some fake loss function

loss, dx = jax.value_and_grad(loss_fn)(x)

for i in range(0, num_layers, -1):
  _, f_vjp = jax.vjp(layer_fn, intermediates[i + 1], weights[i])
  dx, dw = f_vjp(dx)  # compute the jvp dx @ J(L)(x[i], W[i])
  weights[i] = weights[i] - 0.01 * dw  # update our weights

  if i != 0:
    dx = jax.device_put(dx, jax.devices()[i-1])

왜 이것이 좋은 생각일까요? 파이프라이닝은 여러 가지 이유로 훌륭합니다. 파이프라인 단계 간의 통신 비용이 낮으므로 낮은 대역폭 인터커넥트로도 매우 큰 모델을 훈련할 수 있습니다. 이는 TPU처럼 ICI로 밀집 연결되지 않은 GPU에서 종종 매우 유용합니다.

왜 이것이 어렵고/성가실까요? 위의 의사 코드에서 TPU 0이 거의 항상 유휴 상태라는 것을 눈치채셨을 것입니다! 파이프라인의 맨 처음과 마지막 단계에서만 작업을 수행합니다. 유휴 기간을 파이프라인 버블이라고 하며 처리하기가 매우 성가십니다. 일반적으로 마이크로배칭을 사용하여 먼저 이를 완화하려고 시도합니다. 이는 여러 개의 작은 배치를 파이프라인을 통해 보내 TPU 0이 전체 스텝 시간의 더 큰 부분 동안 활용되도록 유지합니다.

두 번째 접근 방식은 순방향 matmul $W_i @ x_i$, 역방향 $dx$ matmul $W_i @ \partial L / \partial x_{i+1}$, 그리고 $dW$ matmul $\partial L / \partial x_{i+1} @ x_i$를 신중하게 중첩하는 것입니다. 이들 각각은 약간의 FLOPs를 필요로 하므로 중첩하여 버블을 완전히 숨길 수 있습니다. 다음은 최근 DeepSeek v3 논문의 플롯으로, 그들의 “bubble-free” 파이프라인 일정을 보여줍니다:

Figure: DeepSeek v3 파이프라인 일정 (최근 논문에서). 주황색은 순방향 matmul, 녹색은 dL/dx matmul, 파란색은 dL/dW matmul입니다. 역방향 dL/dx 곱셈의 우선순위를 지정하여 "좌초된(stranding)" FLOPs를 피할 수 있습니다.

TPU(상호 연결된 더 큰 pod를 가짐)에는 덜 중요하기 때문에 이에 대해 깊이 파고들지는 않겠지만, 주요 파이프라이닝 병목 현상을 이해하는 것은 좋은 연습입니다.

Scaling Across Pods

가장 큰 가능한 TPU 슬라이스는 8960개의 칩(및 2240개의 호스트)이 있는 TPU v5p SuperPod입니다. 이 크기를 넘어 확장하려면 데이터 센터 네트워킹(DCN) 경계를 넘어야 합니다. 각 TPU 호스트에는 이더넷을 통해 다른 TPU v5p pod에 호스트를 연결하는 하나 또는 여러 개의 NIC(네트워크 인터페이스 카드)가 장착되어 있습니다. TPU 섹션에서 언급했듯이, 각 호스트는 약 200Gbps(25GB/s)의 전이중(full-duplex) DCN 대역폭을 가지며, 이는 TPU당 약 6.25GB/s 전이중(송신) 대역폭입니다.

일반적으로 단일 pod를 넘어 확장할 때, ICI 도메인 내에서는 어떤 형태의 모델 병렬 처리 또는 FSDP를 수행하고, 여러 pod에 걸쳐서는 순수 데이터 병렬 처리를 수행합니다. $N$을 확장하려는 TPU 수, $M$을 ICI 연결 슬라이스당 TPU 수라고 합시다. DCN을 통해 AllReduce를 수행하려면 pod 세트에 대해 링 축소(ring-reduction)를 수행할 수 있습니다(역방향 패스에서):

\[T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot BDF}{N \cdot C}\] \[T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot DF}{M \cdot W_\text{dcn}}\]

통신 대역폭은 $M$에 따라 확장됩니다. ICI와 달리 ICI 도메인을 키우고 더 많은 NIC를 확보함에 따라 총 대역폭이 증가하기 때문입니다. 단순화하면 다음과 같을 때 $T_\text{math} > T_\text{comms}$임을 알 수 있습니다.

\[\frac{B}{\text{slice}} > \frac{C}{W_\text{dcn}}\]

TPU v5p의 경우 $\frac{C}{W_\text{dcn}}$은 약 4.46e14 / 6.25e9 = 71,360입니다. 이는 DCN을 통해 효율적으로 확장하려면 각 노드를 송신하는 데 필요한 ICI 도메인당 최소 배치 크기가 있음을 알려줍니다.

이것이 얼마나 문제인가요? 구체적인 예를 들어, BS=2M 토큰으로 TPU v5p에서 LLaMA-3 70B를 훈련하고 싶다고 가정해 봅시다. LLaMA-3 70B는 $F\approx 30,000$입니다. 위의 섹션에서 우리는 다음을 알고 있습니다:

요약하자면 BS=1M, 대략 X(FSDP)=1024 및 Y(TP)=8을 사용하여 훈련할 수 있는 좋은 레시피가 있지만, BS=2M에서는 DCN을 사용해야 합니다. 위에서 언급했듯이 $\text{71,360}$의 DCN 연산 강도를 가지고 있으므로, ICI 도메인당 배치 크기가 이보다 큰지 확인하기만 하면 됩니다. 2개의 pod를 사용하면 pod당 BS가 1M이고 GPU당 배치 크기가 111이므로 이는 우리에게 사소한 문제입니다(약간 아슬아슬할 수 있지만 이론적으로는 건전함).

Takeaway: 여러 TPU pod에 걸친 확장은 pod당 배치 크기가 최소 71k 토큰인 한 순수 데이터 병렬 처리를 사용하여 매우 간단합니다.

Takeaways from LLM Training on TPUs

Strategy Description
Data Parallelism 활성화는 배치 샤딩되고, 다른 모든 것은 완전히 복제되며, 역방향 패스 중에 그라디언트를 all-reduce합니다.
FSDP 활성화, 가중치 및 옵티마이저는 배치 샤딩되며, 가중치는 사용 직전에 gather되고, 그라디언트는 reduce-scatter됩니다.
Tensor Parallelism (aka Megatron, Model) 활성화는 \(d_\text{model}\)에 따라 샤딩되고, 가중치는 \(d_{ff}\)에 따라 샤딩되며, 활성화는 Win 전에 gather되고, 결과는 Wout 후에 reduce-scatter됩니다.
Mixed FSDP + Tensor Parallelism 위의 두 가지 모두이며, FSDP는 모델 샤딩된 가중치를 gather합니다.

그리고 각 방법에 대한 “공식”은 다음과 같습니다:

\[\small \begin{array}{cc} \text{Strategy} & \text{Formula}\\ \hline \text{DP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D] \\ \text{FSDP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D] \\ \text{TP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \\ \text{TP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \\ \text{TP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \\ \text{TP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \\ \hline \end{array}\] \[\small \begin{array}{ccc} \text{Strategy} & \text{Compute per layer} & \text{Comms per layer} \\ & \text{(ignoring gating einsum)} & \text{(bytes, forward + backward pass)}\\ \hline \text{DP} & 4BDF/X + 8BDF/X & 0 + 8DF \\ \text{FSDP} & 4BDF/X + 8BDF/X & 4DF + 8DF \\ \text{TP} & 4BDF/Y + 8BDF/Y & 4BD + 4BD \\ \text{FSDP + TP} & 4BDF/(XY) + 8BDF/(XY) & (4BD/X + 4DF/Y) + (8BD/X + 8DF/Y) \\ \text{TP} & 4BDF/Y + 8BDF/Y & 4BD + 4BD \\ \text{FSDP + TP} & 4BDF/(XY) + 8BDF/(XY) & (4BD/X + 4DF/Y) + (8BD/X + 8DF/Y) \\ \hline \end{array}\]

Some Problems to Work

이 섹션의 기본 모델로 LLaMA-2 13B를 사용합시다. 모델 세부 정보는 다음과 같습니다:

hyperparam value
L 40
D 5,120
F 13824
N 40
K 40
H 128
V 32,000

LLaMA-2에는 별도의 임베딩 및 출력 행렬과 게이트형 MLP 블록이 있습니다.

Question 1: LLaMA-2 13B에는 얼마나 많은 파라미터가 있나요(어리석은 질문인 건 알지만 계산해 보세요)? Transformer Math에서와 같이 LLaMA-3에는 3개의 큰 FFW 행렬, 2개의 up-projection 및 1개의 down-projection이 있습니다. 이 섹션에서는 두 개의 “게이팅” einsum 행렬을 무시했지만, 이 섹션에서는 Win과 동일하게 동작합니다.

답을 보려면 여기를 클릭하세요.
  • FFW 파라미터: \(3LDF\) = 8.5e9
  • Attention 파라미터: \(4DNHL\) = 4.2e9
  • Vocabulary 파라미터: \(2VD\) = 0.3e9
  • 합계: 8.5e9 + 4.2e9 + 0.39e9 = 13.1e9, 예상대로입니다!

Question 2: BS=16M 토큰으로 훈련하고 Adam을 사용한다고 가정해 봅시다. 잠시 병렬 처리를 무시하고, 모델의 파라미터, 옵티마이저 상태, 활성화에 사용되는 총 메모리는 얼마인가요? 파라미터는 bf16에, 옵티마이저 상태는 fp32에 저장하고 레이어당 세 번(세 개의 큰 matmul 이후) 활성화를 체크포인트한다고 가정합니다.

답을 보려면 여기를 클릭하세요.

파라미터(bf16)와 두 개의 옵티마이저 상태(fp32, 1차 및 2차 모멘트 누적기)에 사용되는 총 메모리는 (2 + 4 + 4) * 13e9 ~ 130GB입니다. 처음 두 개의 matmul 후의 활성화는 모양이 $BF$이고 마지막 후에는 $BD$이므로(위의 Transformer 다이어그램에 따라), bf16의 총 메모리는 $2 \cdot L \cdot (BD + 2 * BF) = 2LB \cdot (D + 2F)$ 또는 2 * 40 * 16e6 * 5,120 * (1 + 2 * 2.7) ~ 4.2e13 = 42TB입니다(B=16e6이므로). 다른 모든 활성화는 다소 무시할 수 있습니다.

Question 3: TPUv5p 16x16x16 슬라이스에서 32k 시퀀스 길이와 총 배치 크기 3M 토큰으로 훈련하고 싶다고 가정해 봅시다. 위와 같이 bfloat16 가중치와 float32 옵티마이저를 사용하고 싶다고 가정합니다.

  1. 순수 데이터 병렬 처리를 사용할 수 있나요? 그 이유는 무엇인가요?
  2. 순수 FSDP를 사용할 수 있나요? 그 이유는 무엇인가요? 순수 FSDP를 사용하면 디바이스당 얼마나 많은 메모리가 사용되나요(3개의 큰 FFW 행렬 후에만 gradient checkpointing을 수행한다고 가정).
  3. 혼합 FSDP + 텐서 병렬 처리를 사용할 수 있나요? 그 이유는 무엇인가요? 그렇다면 $X$와 $Y$는 무엇이어야 하나요? 디바이스당 얼마나 많은 메모리가 저장되나요? 루프라인 FLOPs 추정치만 사용하고 어텐션을 무시할 때, 40% MFU에서 각 훈련 단계는 얼마나 걸리나요?
답을 보려면 여기를 클릭하세요.

먼저 몇 가지 숫자를 적어 봅시다. 32k 시퀀스 길이와 3M 배치 크기로, 시퀀스 배치 크기는 96입니다. TPU v5p 16x16x16 슬라이스에는 총 393TB의 HBM이 있습니다.

  1. 순수 데이터 병렬 처리는 사용할 수 없습니다. 각 칩에 파라미터와 옵티마이저 상태를 복제하는데, 이는 이미 약 130GB(Q2에서)로 칩당 HBM(96GB)보다 많기 때문입니다.

  2. 메모리만 보는 것으로 시작해 봅시다. Q2에서 BS=16M을 3M으로 바꾸면 ~7.86e12 총 체크포인트 활성화를 얻고, 1.3e11 옵티마이저 상태를 더하면 거의 정확히 8e12 = 8TB가 됩니다. TPUv5p 슬라이스는 총 393TB의 HBM을 가지고 있으므로 HBM 제한 아래에 안전하게 있습니다. 다음으로 comms-bound인지 compute-bound인지 살펴봅시다. 4096개의 칩과 3개의 병렬 처리 축으로 850 * 4096 = 3.48M 토큰의 최소 배치 크기를 수행할 수 있습니다. 이는 3M 배치 크기보다 약간 높습니다. 따라서 실제로 comms-bound이며, 이는 슬픈 일입니다. 따라서 일반적인 대답은 아니요, FSDP만으로는 할 수 없습니다.

  3. 이제 우리의 주된 관심사가 comms-bound라는 것을 알았으므로 숫자를 대입해 봅시다. 우선 위에서 혼합 FSDP + 텐서 병렬 처리를 사용하는 칩당 배치 크기가 $2550^2 / 2F = 235$ 이상이어야 한다는 것을 알고 있습니다. 즉, 이론적으로 이것을 할 수 있습니다! 각각 얼마인지 알아봅시다.

$X_{opt} = \sqrt((F / B) * (M_X / M_Y) * N)$ 규칙이 있으므로, 여기서는 sqrt(3e6 * 2 * 4096 / 13824) = 1333을 가지며, 이는 대략 1024 방향 DP와 4 방향 TP를 수행한다는 것을 의미합니다. TPU당 메모리는 (2)와 같을 것이며, 스텝 시간은 6 * 3e6 * 13e9 / (4096 * 4.6e14 * 0.4) = 300ms가 될 것입니다.

파트 5는 여기까지입니다! 실제 LLaMA 모델에 이 내용을 적용하는 파트 6을 보려면, 여기를 클릭하세요!

Appendix

Appendix A: Deriving the backward pass comms

위에서 Transformer 레이어 순방향 패스를 Out[B, D] = In[B, D] *D Win[D, F] *F Wout[F, D]로 단순화했습니다. 역방향 패스에 필요한 통신은 어떻게 유도할까요?

이는 단일 matmul Y = X * A에 대한 이전 섹션의 규칙에서 아주 자연스럽게 따릅니다:

\[\frac{dL}{dA} = \frac{dL}{dY}\frac{dY}{dA} = X^T \left(\frac{dL}{dY}\right)\] \[\frac{dL}{dX} = \frac{dL}{dY}\frac{dY}{dX} = \left(\frac{dL}{dY}\right) A^T\]

이를 사용하여 다음 공식을 얻습니다(Tmp[B, F]는 In[B, D] * Win[D, F]를 나타냄):

  1. dWout[F, D] = Tmp[B, F] *B dOut[B, D]
  2. dTmp[B, F] = dOut[B, D] *D Wout[F, D]
  3. dWin = dTmp[B, F] *B Tmp[B, F]
  4. dIn[B, D] = dTmp[B, F] *F Win[D, F]
  5. dWout[F, D] = Tmp[B, F] *B dOut[B, D]
  6. dTmp[B, F] = dOut[B, D] *D Wout[F, D]
  7. dWin = dTmp[B, F] *B Tmp[B, F]
  8. dIn[B, D] = dTmp[B, F] *F Win[D, F]

이 공식들은 샤딩에 대한 언급이 없는 수학적 진술이라는 점에 유의하세요. 역방향 패스의 작업은 이 네 가지 수량을 계산하는 것입니다. 따라서 필요한 통신을 파악하려면 위의 네 가지 방정식(Tmp, dOut, Wout, Win)에서 matmul될 모든 수량의 샤딩(병렬화 방식에 의해 지정됨)을 가져와 샤딩된 matmul의 규칙을 사용하여 수행해야 할 통신을 파악하면 됩니다. dOut은 Out과 같은 방식으로 샤딩된다는 점에 유의하세요.

Miscellaneous

*Work done at Google DeepMind, now at MatX.

Citation

For attribution in academic contexts, please cite this work as:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

or as a BibTeX entry:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }