Part 8 of How To Scale Your Model (Part 7: Inference | Part 9: Profiling)
TPU v5e에서 LLaMA 3-70B 모델을 서빙하는 방법을 자세히 살펴보겠습니다. roofline에서 서빙하는 데 비용이 얼마나 들까요? KV 캐시의 크기는 얼마일까요? 어떤 배치 크기를 사용해야 할까요? 추론 중 파라미터와 활성화는 어떻게 샤딩될까요? 프로덕션 환경에서의 지연 시간과 처리량에 대한 대략적인 추정치를 계산해 보겠습니다.
번역 안내: 원저자(Jacob Austin)의 허락을 받아 원문을 번역 중입니다.
해당 글의 1인칭은 원문 저자를 지칭합니다.
원문: How to Scale Your Model
번역: 신종훈
이 섹션에서는 LLaMA-3를 서빙하는 데 필요한 것과 얼마나 효율적으로 수행할 수 있는지 살펴볼 것입니다. 이전의 “applied” 섹션과 마찬가지로, 정답을 찾아보기 전에 펜과 종이를 가지고 스스로 답을 찾아보세요!
LLaMA 3-70B가 어떻게 생겼는지 상기해 봅시다 (섹션 6 참조):
| hyperparam | value |
|---|---|
| \(n_\text{layers}\) (L) | 80 |
| \(d_\text{model}\) (D) | 8,192 |
| \(d_{ff}\) (F) | 28,672 |
| \(n_\text{heads}\) (N) | 64 |
| \(n_\text{kv heads}\) (K) | 8 |
| \(d_\text{qkv}\) (H) | 128 |
| \(n_\text{embeddings}\) (V) | 128,256 |
간단한 질문으로 시작해 봅시다: 어떤 하드웨어에서 서빙해야 할까요? 정답은 기본적으로 FLOPs / dollar가 가장 저렴한 것입니다.
| TPU type | bfloat16 FLOPs/s | Google Cloud USD / hour | FLOPs / $ |
|---|---|---|---|
| H100 | 9.9e14 | $10.8 | 3.3e17 |
| v5p | 4.59e14 | $4.2 | 3.9e17 |
| v5e | 1.97e14 | $1.2 | 5.8e17 |
각 TPU v5e는 16GB의 HBM을 가지고 있어 모델을 상당히 공격적으로 샤딩해야 합니다. 우리에게 중요할 수 있는 몇 가지 기본 수량을 생각하는 것부터 시작해 봅시다:
Question: LLaMA 3-70B의 토큰당 KV 캐시는 얼마나 큰가요? int8로 저장한다고 가정할 수 있습니다. 이는 주어진 토폴로지에서 배치 크기가 얼마나 클 수 있는지를 결정합니다.
LLaMA 3-70B는 8개의 KV 헤드를 가지고 있으므로 토큰당 크기는 2 * K * H * L = 2 * 8 * 128 * 80 = 160kB입니다.
이것이 얼마나 큰지 주목하세요! 시퀀스 길이가 32k 토큰인 경우(일반적임) 162e3 * 32,768 = 5.3GB / sequence를 사용합니다. BS=240의 경우 이는 1.3TB입니다! TPU v5e는 각각 16GB만 가지고 있으므로 이만큼의 메모리를 맞추려면 약 (70e9 + 1.3e12) / 16e9 = 86개의 TPU v5e 칩이 필요합니다. 또한 이것이 70GB의 모델 파라미터에 비해 얼마나 큰지 주목하세요.
Question: L3 70B를 배치 크기 32, 시퀀스 길이 8192, 모든 것(파라미터 및 KV)을 int8로 서빙하고 싶다고 가정해 봅시다. 총 메모리는 얼마나 사용될까요? 이것을 서빙할 수 있는 가장 작은 슬라이스는 무엇인가요?
KV는 int8에서 160e3 바이트이므로 총 KV 메모리는 160e3 * 8192 * 32 = 41.9e9 바이트입니다. 파라미터당 1바이트이므로 파라미터는 70e9 바이트입니다. 따라서 총 메모리 사용량은 41.9e9 + 70e9 = 112GB입니다.
사용할 수 있는 가장 작은 슬라이스는 112e9 / 16e9 = 7개의 TPU, 즉 (짝수 크기로 반올림하여) TPU v5e 4x2가 될 것입니다. 이것은 빡빡할 수 있으며 다른 오버헤드를 고려할 때 딱 맞지 않을 수도 있으므로 최소한 4x4가 필요할 수 있습니다(또는 배치 크기를 줄여야 함).
Question: TPU v5e 4x2에서 이 배치 크기와 양자화로, 디코드 단계당 대략 어떤 지연 시간을 예상할 수 있나요? 처리량(tokens / sec / chip)은 어떤가요? 4x4는 어떨까요? FLOPs는 bfloat16에서 수행되고 모든 것이 완전히 샤딩되었다고 가정합니다.
이전 섹션의 공식을 호출할 수 있습니다.
\[\begin{align*} \tiny \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\tiny \text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\tiny \text{MLP (can be compute-bound)}} \end{align*}\]여기서 파라미터는 int8이지만 FLOPs는 bfloat16이므로 임계 배치 크기는 약 120이 될 것입니다. RHS 최대값을 수동으로 계산할 수도 있지만, 이는 기본적으로 우리가 이미 여러 번 수행한 계산입니다. 따라서 우리는 matmul과 FLOPs 모두에 대해 memory-bound 영역에 잘 들어와 있습니다.
엄밀히 메모리 대역폭만 보면, 단계 시간은 기본적으로 (KV size + param size) / (8 * HBM bandwidth) = 112e9 / (8 * 8.1e11) = 17ms입니다. 따라서 이론적으로 단계 시간은 약 17ms입니다. 처리량은 32 / .017 = 1882 tokens / sec 또는 1882 / 8 = 235 tokens / sec / chip입니다.
여기서 한 가지 주의할 점은 matmul에서 ICI bound가 될 수 있는지 확인하는 것입니다. 여기서 2개의 축을 할당할 수 있으므로 이론적으로 $Y > 2 * F / 2200 = 2 * 28672 / 2200 = 26$일 때 ICI bound인데, 우리는 괜찮습니다!
4x4에서 실행하더라도 ICI 측면에서 여전히 괜찮으므로 지연 시간은 17 / 2 = 8.5ms로 떨어지지만 칩당 처리량은 동일하게 유지됩니다.
처리량에 대해서만 생각하는 시간을 가져봅시다. 처리량을 최적화할 때 우리는 compute bound가 되기를 원합니다. 즉, 모든 TPU MXU 용량을 활용하는 데 가까워지는 것입니다. 일반적으로 이는 배치 크기를 가능한 한 크게 하여 가능한 한 많은 작업을 수행하고자 함을 의미합니다.
Question: TPU v5e에서 bfloat16 가중치와 활성화를 사용할 때, matmul에서 compute-bound가 되려면 배치 크기가 얼마나 커야 하나요? int8 가중치를 사용하지만 FLOPs는 bfloat16으로 수행하면 어떨까요? int8 가중치와 int8 FLOPs는 어떨까요?
섹션 7에서 논의했듯이, $B \ll D, F$인 bfloat16 matmul의 경우 다음이 성립합니다.
\[\begin{equation*} T_\text{math} > T_\text{comms} \leftrightarrow \frac{2BDF}{2DF} \geq \frac{\text{TPU bfloat16 FLOPs/s}}{\text{HBM bandwidth}} = 240 \end{equation*}\]가중치가 int8일 때 분모에서 2배를 잃게 되므로 $2BDF / DF = 2B > 240$, 또는 동일하게 $B > 120$이 되며, 이는 이전 임계 배치 크기의 절반입니다. 이것은 우리에게 정말 도움이 됩니다! int8 가중치와 int8 FLOPs를 수행할 때 TPU FLOPs/s에 대해 int8 값을 사용해야 하는데, 이는 bfloat16의 1.97e14에서 거의 두 배인 3.94e14로 증가합니다. 즉, 우리는 다시 약 $B > 240$으로 돌아갑니다.
int8 가중치와 bfloat16 FLOPs의 경우는 꽤 흔합니다. 파라미터를 손실 없이 양자화하는 것이 저정밀도 산술을 수행하는 것보다 종종 더 쉽기 때문입니다.
Question: 8k 컨텍스트로 bfloat16, int8, int4(KV 및 파라미터 모두)를 사용하여 LLaMA 3-70B를 서빙할 수 있는 가장 작은 TPU v5e 토폴로지는 무엇인가요? 이 문제에서 KV 캐시는 무시할 수 있을 만큼 작다고 생각할 수 있습니다.
쉽습니다! 아주 작은 배치 크기에 괜찮다면 유일한 제한은 HBM에 파라미터 메모리를 맞추는 것입니다. 즉, ceil(num_params * sizeof(dtype) / HBM per TPU, 또는 ceil(70e9 * sizeof(dtype) / 16e9)를 가장 가까운 합리적인 토폴로지(2의 배수)로 반올림한 것입니다:
| dtype | param size | KV size / token (bytes) | min TPU v5es | actual min slice | remaining HBM for KV caches | num KV caches @ 8k |
|---|---|---|---|---|---|---|
| bf16 | 140GB | 324kB | 8.75 | 4x4 = 16 chips | 116 | 43 |
| int8 | 70GB | 162kB | 4.38 | 4x2 = 8 chips | 58 | 43 |
| int4 | 35GB | 81kB | 2.81 | 2x2 = 4 chips | 29 | 43 |
| int8 | 70GB | 162kB | 4.38 | 4x2 = 8 chips | 58 | 43 |
| int4 | 35GB | 81kB | 2.81 | 2x2 = 4 chips | 29 | 43 |
정말 멋집니다! 원한다면 LLaMA 70B를 TPU v5e 2x2에 맞출 수 있다는 것을 말해줍니다. 단, KV 캐시의 수가 매우 적다는 것을 알 수 있습니다. 그것이 바로 우리의 배치 크기입니다! 즉, 우리는 끔찍한 FLOPs 활용률을 얻게 될 것입니다. 배치 크기를 240까지 올리기 위해 더 큰 토폴로지를 사용하는 것이 훨씬 더 기쁠 것입니다.
Question: 이 토폴로지들에 맞는 가장 큰 배치 크기를 사용한다고 가정할 때, 각 generate 단계에 대해 어떤 지연 시간을 예상할 수 있을까요?
HBM을 가득 채울 배치 크기를 선택하고 있기 때문에 이것 또한 쉽습니다! 이것은 단지 전체 TPU v5e 분량의 바이트를 MXU로 로드하는 데 얼마나 걸리는지의 문제입니다. 이것은 단지 v5e HBM / v5e HBM memory bandwidth = 16GB / 8.2e11 = 19ms이므로 19ms / step입니다. 생성의 중앙값 길이가 512 토큰이라고 가정하면 각 디코딩에 약 9초가 걸립니다. 더 작은 배치 크기로 약간 더 나은 지연 시간을 얻을 수 있습니다. 예를 들어 int4의 모델 파라미터만 본다면 HBM이 더 이상 꽉 차 있지 않으므로 최소 지연 시간은 약 10ms / step입니다.
Takeaway: HBM에서 MXU로 모든 모델 파라미터를 로드하는 데 걸리는 시간을 물어봄으로써 항상 디코드 지연 시간의 하한을 정할 수 있습니다. KV 캐시가 작을 때, 각 레이어를 가중치를 청크 단위로 로드하고 폐기하는 것으로 생각할 수 있습니다. 큰 배치 크기나 많은 기기 간 통신을 사용하지 않는 한, 이는 종종 합리적인 경계입니다(1.5배 이내). 배치 크기가 클 때는 KV 캐시 로딩이 파라미터를 지배하므로 KV 캐시 로딩도 모델링해야 합니다.
마찬가지로 FLOPs-bound 영역(예: 훈련 또는 대규모 배치 추론)에서는 통신이 없다고 가정하는 \(\text{Total FLOPs} / (N \cdot C) = 2 \cdot \text{param count} \cdot B / (N \cdot C)\) 하한을 사용할 수 있습니다.
Question: 이들 각각에 대해, 이것이 우리에게 주는 칩당 처리량은 얼마인가요(queries / chip 측면에서)? 중앙값 디코드 길이는 512 토큰이라고 가정할 수 있습니다.
이것은 중요한 질문입니다. 비용 / 토큰과 정확히 상관관계가 있기 때문입니다.
중앙값 디코드 길이에 대한 가정하에 처리량은 단지 \(B / (\text{per-step latency} \cdot \text{median steps} \cdot N) \approxeq 43 / (0.019 * 512 * N)\)입니다. 이는 대략 \((4.42 / N)\)QPS를 제공하므로\(N\)을 대입하면 다음과 같습니다:
| dtype | QPS / chip |
|---|---|
| bfloat16 | 0.27 |
| int8 | 0.55 |
| int4 | 1.11 |
| int8 | 0.55 |
| int4 | 1.11 |
이것은 순방향 패스의 작업 메모리(활성화 및 어텐션에 할당된 메모리)를 완전히 무시하므로 다소 낙관적입니다. Flash Attention을 사용하면 터무니없지는 않지만 현실적이지도 않습니다. 실제 숫자는 아마도 이것의 1/2 정도일 것입니다. 절대적인 최대 처리량을 위해서는 아마도 칩 수를 두 배 이상 늘리고 배치 크기도 크게 늘리고 싶을 것입니다.
Question: 위의 각 예에 대해 토폴로지를 두 배로 늘리면 피크 처리량은 어떻게 변할까요?
bfloat16에서 4x8 슬라이스를 사용하면 KV 캐시를 위해 372GB가 남게 되며, 이를 통해 배치 크기를 140까지 늘릴 수 있습니다. 그러면 단계 시간은 동일하게 유지되므로 14.39 / num_chips의 처리량을 갖게 되며, 이는 다음과 같습니다.
| dtype | QPS / chip |
|---|---|
| bfloat16 (on 4x8) | 0.44 |
| int8 (on 4x4) | 0.90 |
| int4 (on 2x4) | 1.80 |
더 늘리면 더 큰 승리를 얻을 수 있습니다! 큰 결론은 KV 캐시 크기에 의해 제한되는 경우 가장 작은 토폴로지가 모든 경우에 가장 성능이 좋은 토폴로지는 아니라는 것입니다.
Question: 이제 샤딩 문제에 대해 파고들어 봅시다. TPU v5e 4x8에서 bfloat16으로 서빙하고 싶다고 가정해 봅시다. generation 중 TPU v5e 4x8에서 모델에 어떤 샤딩을 사용하시겠습니까? comms bound가 되는 것을 피할 수 있을까요?
이전 섹션에서 논의했듯이 generation 중에는 실제로 샤딩을 위한 하나의 옵션만 있습니다: 모델 병렬 처리. comms bound가 되기 전까지 얼마나 많이 할 수 있을까요? 이전 섹션에서 논의했듯이 모델은 대략 다음과 같을 때 comms bound가 됩니다.
\(Y > \frac{F \cdot M_Y}{2200}\) \(Y > \frac{F \cdot M_Y}{2200}\)
LLaMA 3-70B의 경우 F = 28,672이므로, 2축의 모델 샤딩을 수행하면 대략 \(Y = 28672 \cdot 2 / 2200 = 26\)을 얻으므로 일반적으로 comms bound 없이 약 16개의 칩까지 확장할 수 있으며, 이는 4x4를 사용할 수 있게 해주지만 4x8은 아닙니다. 일반적으로 계산을 완벽하게 중첩하지 않으므로 이 추정치조차 지나치게 낙관적입니다.
Takeaway: 순수 모델 병렬 처리로는 4x8에서 실제로 서빙할 수 없습니다. 여기서 할 수 있는 최선은 4x2 또는 아마도 4x4입니다.
그러나 논의했듯이 배치 크기가 작을 때 모델은 FLOPs bound가 아니라 memory-bandwidth-bound이므로 처리량에 큰 해를 끼치지 않고 종종 더 많은 모델 병렬 처리를 수행할 수 있습니다. 우리는 이전에 이 값이 대략 $Y=F / (8\cdot B)$라고 말했으므로, 배치 크기 64를 수행하면 이론적으로 ICI-bound가 되기 전에 최대 Y = 28,672 / (8 * 64) = 56방향 모델 병렬 처리까지 갈 수 있습니다. 이를 확인하기 위해 단일 matmul에 대해 $T_\text{ici comms}$, $T_\text{hbm comms}$, $T_\text{math}$를 살펴볼 수 있습니다. 분명히 다음을 얻습니다:
4x8의 경우 $T_\text{ici comms}$ = (2 * 64 * 8192) / 9e10 = 11us, $T_\text{hbm comms}$ = (2 * 8192 * 28,672) / (32 * 8.1e11) = 18us, $T_\text{math}$ = (2 * 64 * 8192 * 28,672) / (32 * 1.97e14) = 4us가 되므로 이론적으로 우리는 여전히 HBM 대역폭 제한 상태이며, 이는 훌륭합니다! *참고로 4x4에서 4x8로 확장하는 것은 처리량 관점에서는 도움이 되지 않을 수 있지만 지연 시간은 줄어들 것입니다!
int8 및 int4 구성을 보면 순수 모델 병렬 처리로 수행할 수 있습니다. 따라서 우리는 양자화가 실제로 더 빠른 FLOPs 이상의 의미 있는 이점을 제공하는 지점에 도달했습니다: comms-bound가 되기 전에 더 큰 배치 크기를 사용할 수 있게 해줍니다. 따라서 이 이야기의 결론은 4x8에서 최대 처리량을 달성할 수는 없지만 int8 및 int4 구성의 경우 순수 모델 병렬 처리를 수행할 수 있다는 것입니다.
Tip: 유용한 모델 병렬 처리의 최대량은 \(d_{ff}\)와 모델을 샤딩하는 축의 수에 따라 다릅니다. 최대값은 일반적으로 모델 크기에 따라 8에서 32 사이입니다. 이 한계를 넘어서 확장하여 약간의 처리량 비용으로 지연 시간을 개선할 수 있습니다.
prefill은 훨씬 간단하기 때문에 여기서 대부분 무시했습니다. 몇 가지 개념을 합쳐서 엔드투엔드 그림을 생각해 봅시다.
Question: prefill 중 40% FLOPs 활용률을 달성한다고 가정해 봅시다. 16개의 TPU v5e 칩에서 길이 8192의 prefill은 얼마나 걸릴까요?
8k 토큰에서 우리는 확실히 compute bound이므로 FLOPs에 대해서만 생각하면 됩니다. 모델이 70e9 파라미터를 가지고 있으므로 각 순방향 패스는 2 * 70e9 * B FLOPs를 사용한다는 것을 알고 있습니다. 40% MFU(FLOPs 활용률)를 가정하면, 이는 약 2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s의 실행 시간을 제공합니다. 이전에 살펴보았던 숫자들과 비교하면 실제로는 꽤 깁니다!
Question: 중앙값 prefill 길이가 8192 토큰이고 중앙값 decode 길이가 4096 토큰이라고 가정해 봅시다. generate 배치 크기가 32라고 가정합니다. 평균적으로 단계당 얼마나 많은 시퀀스가 디코딩을 완료합니까? 평균적으로 매 단계마다 얼마나 많은 토큰이 KV 캐시에서 퇴출됩니까?
이것은 일종의 간단한 문제입니다. 중앙값 decode 길이가 4096 토큰이므로 시퀀스는 대략 1 / 4096 토큰마다 완료됩니다. 배치 크기가 32인 경우 단계당 32 / 4096개의 시퀀스가 퇴출됩니다. KV 캐시 길이는 대략 8192 + 4096이므로, 이는 단계당 32 * (8192 + 4096) / 4096 = 96개의 토큰이 퇴출됨을 의미합니다. 일반 공식은 $B * (P + G) / G$이며 여기서 $P$와 $G$는 prefill 및 generate 길이입니다.
Question: 중앙값 prefill 길이가 8192이고 중앙값 decode 길이가 512인 분산형 서빙을 수행한다고 가정해 봅시다. bfloat16에서 위에서 계산된 prefill 및 generate 지연 시간을 가정합니다. 둘 다 완전히 포화 상태로 유지하려면 prefill:generate 서버의 비율이 어떻게 되어야 할까요?
이것은 꽤 재미있는 질문입니다. $P$를 prefill 서버 수, $G$를 generate 서버 수라고 합시다. 따라서 일반적으로 말해서, 이것은 시퀀스를 P / prefill_latency 속도로 공급하고 B * G / (generate_latency * median_decode_length) 속도로 소비하는 파이프라인 문제입니다. 배치 크기 43(32라고 부릅시다)에서 prefill 단계당 910ms 및 decode 단계당 19ms를 계산했습니다. 따라서 P / 0.91 = 32 * G / (0.019 * 512) 또는 P = 3G, 즉 생성 서버보다 약 3배 더 많은 prefill 서버가 필요합니다!
잠시 LLaMA 70B를 고수하며 generation 중 다양한 배치 크기에 대한 지연 시간과 처리량을 실제로 살펴보겠습니다. 이전 섹션에서 PaLM 모델에 대해 보여주었듯이 이것은 처리량/지연 시간에 대한 파레토 프런티어를 제공합니다. MLP 블록에서 compute-bound를 유지하면서 사용할 수 있는 합리적인 한계이므로 16방향 텐서 병렬 처리를 가정해 보겠습니다. 여기서는 TPU v5e 4x4 토폴로지를 사용합니다. 슬라이더는 시퀀스 길이를 제어하므로 더 큰 KV 캐시의 효과를 볼 수 있습니다.
비용과 지연 시간의 원인을 파라미터 로딩 시간, KV 로딩 시간, FLOPs 시간으로 나누어 더 잘 이해할 수 있습니다. 빨간색 섹터는 MLP 블록에서 compute-bound가 될 것으로 예상되는 영역입니다.
이것은 꽤 많은 이야기를 해줍니다. 처음에 파라미터 로딩이 지연 시간의 대부분을 차지하다가 배치 크기가 충분히 커지면 FLOPs와 KV 로딩이 더 중요해지는 것을 볼 수 있습니다. 특히 2048보다 큰 모든 시퀀스 길이에서 FLOPs보다 KV 캐시 로딩에 더 많은 시간을 소비합니다! 따라서 배치 크기를 늘려 하드웨어 활용률을 개선할 수 있지만, 긴 컨텍스트 길이에서는 KV 로딩이 항상 총 단계 시간을 지배합니다.
Takeaway: LLaMA 3-70B의 경우 거의 모든 구성에서 강력하게 KV 캐시 메모리 대역폭 제한(및 HBM 제한) 상태이며, 이는 generation 처리량을 위해 KV 캐시 크기를 줄이는 것이 얼마나 중요한지 강조합니다. 또한 여기서 지연 시간/처리량 트레이드오프가 얼마나 극적인지 주목하세요.
다음은 이러한 루프라인을 계산하는 코드입니다:
import numpy as np
num_chips = 16 # we fix 16 as the amount of total model parallelism we do
param_size = 70e9 # int8 means 1 byte per param
sequence_length = 8192 # can vary this
hbm_bandwidth = 8.20E+11 # v5e
flops = 1.97E+14 # v5e
param_size = bytes_per_param * param_count
def kv_cache_size(bs):
return 2 * bs * 128 * 8 * 80
def min_topology(bytes):
return 2 ** np.ceil(np.log2(bytes / 16e9))
def get_max_batch_size(max_num_chips: int = 16):
# for num_chips in topo_sizes:
batch_sizes = np.arange(1, 1024, 4)
kv_sizes = kv_cache_size(sequence_length * batch_sizes)
num_chips = min_topology(kv_sizes + param_size)
max_idx = np.where(num_chips <= max_num_chips)[0][-1]
return max_idx
max_idx = get_max_batch_size(num_chips, sequence_length, param_size) # get the largest batch size that can fit
batch_sizes = np.arange(1, 512, 1)[:max_idx]
kv_sizes = kv_cache_size(sequence_length * batch_sizes)
kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)
param_comms_time = param_size / (num_chips * hbm_bandwidth)
param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])
flops_time = 2 * param_count * batch_sizes / (num_chips * flops) # roughly true in a 2ND sense
mlp_time = np.maximum(flops_time, param_comms_time)
attn_time = kv_comms_time # always bandwidth-bound for generate
latency = 1000 * (mlp_time + attn_time)
throughput = batch_sizes / (latency * num_chips)
우리가 지연 시간을 KV 로딩과 파라미터 로딩이라는 두 가지 소스로 매우 명시적으로 나누고, 지연 시간이 FLOPs 또는 comms 중 더 큰 것에 의해 제한되는 방식을 주목하세요.
다음은 몇 가지 풀이 문제입니다. 이 중 일부는 위에서 다룬 내용을 반복하지만 교육적으로 유용할 수 있습니다.
Question 1: LLaMA 3-405B에 대한 각 순방향 패스는 토큰당 몇 FLOPs를 사용하나요? FLOPs bound라고 가정할 때 TPU v5e의 N개 칩에서 단일 순방향 패스의 하한은 얼마인가요? comms bound라면 어떨까요? 모델이 단일 칩에 맞지 않는다는 사실은 무시하세요.
Question 2: int8 가중치와 int8 KV 캐시를 사용하여 BS240으로 LLaMA 3-8B를 서빙하고 싶다고 가정해 봅시다. (a) 모델 파라미터 (b) KV 캐시 (c) 피크 작업 활성화(대략)에 사용되는 바이트 수는 얼마인가요? 이것을 실행할 수 있는 가장 작은 토폴로지는 무엇인가요?
Question 3: TPU v5e에서 LLaMA 3-405B를 어떻게 서빙하시겠습니까? int8 가중치와 bfloat16 FLOPs를 가정합니다. 15ms / token의 확고한 제한이 있다고 가정할 때 달성할 수 있는 가장 높은 처리량 구성은 무엇인가요? 이론적 최소 단계 시간은 얼마인가요?