1 KV cache
During inference, a common strategy to accelerate Transformer models is to use a KV cache. A typical large model's generative inference consists of two stages:
- prefill phase: take the prompt sequence as input and build the key cache and value cache for each transformer layer.
- decode phase: use and update the KV cache to generate tokens one by one, where each newly generated token depends on all previously generated tokens.
1.1 Prefill phase
Suppose the input of the $i$-th layer is $x^i \in \mathbb{R}^{b \times s \times h}$ (batch size $b$, sequence length $s$, hidden dimension $h$), and the key, value, query and output weight matrices of self-attention are $W_K^i, W_V^i, W_Q^i, W_O^i \in \mathbb{R}^{h \times h}$. Then we can derive the key cache and value cache by

$$x_K^i = x^i \cdot W_K^i, \qquad x_V^i = x^i \cdot W_V^i$$

The remaining operations are

$$x_Q^i = x^i \cdot W_Q^i$$

$$x_{\text{out}}^i = \mathrm{softmax}\left(\frac{x_Q^i \, (x_K^i)^{\top}}{\sqrt{h}}\right) \cdot x_V^i \cdot W_O^i + x^i$$
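To make the prefill computation concrete, here is a minimal single-head NumPy sketch of one layer (layer norm and multi-head splitting omitted; the causal mask, left implicit in the formula above, is included; names such as `prefill` and `W_K` are illustrative, not from any particular library):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def prefill(x, W_Q, W_K, W_V, W_O):
    """Self-attention over the whole prompt x of shape (s, h).

    Returns the layer output plus the key/value caches that
    the decode phase will reuse.
    """
    s, h = x.shape
    x_K = x @ W_K                      # key cache,   (s, h)
    x_V = x @ W_V                      # value cache, (s, h)
    x_Q = x @ W_Q                      # queries,     (s, h)
    scores = x_Q @ x_K.T / np.sqrt(h)  # (s, s)
    scores += np.triu(np.full((s, s), -np.inf), k=1)  # causal mask
    out = softmax(scores) @ x_V @ W_O + x             # residual
    return out, x_K, x_V
```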
1.2 Decode phase
Given $t^i \in \mathbb{R}^{b \times 1 \times h}$, the vector representation of the currently generated token at the $i$-th transformer layer, the computation is mainly divided into two steps (updating the KV cache and computing the output):

- updating the KV cache:

$$x_K^i \leftarrow \mathrm{Concat}\left(x_K^i,\; t^i \cdot W_K^i\right)$$

$$x_V^i \leftarrow \mathrm{Concat}\left(x_V^i,\; t^i \cdot W_V^i\right)$$

- calculating the output:

$$t_Q^i = t^i \cdot W_Q^i$$

$$t_{\text{out}}^i = \mathrm{softmax}\left(\frac{t_Q^i \, (x_K^i)^{\top}}{\sqrt{h}}\right) \cdot x_V^i \cdot W_O^i + t^i$$
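The corresponding decode step, in the same illustrative sketch style as above:

```python
def decode_step(t, x_K, x_V, W_Q, W_K, W_V, W_O):
    """Self-attention for a single new token t of shape (1, h).

    Appends the token's key/value to the caches, then attends
    over the full updated cache. Returns output and new caches.
    """
    h = t.shape[-1]
    x_K = np.concatenate([x_K, t @ W_K], axis=0)  # update key cache
    x_V = np.concatenate([x_V, t @ W_V], axis=0)  # update value cache
    t_Q = t @ W_Q                                  # (1, h)
    scores = t_Q @ x_K.T / np.sqrt(h)              # (1, s+1), no mask needed
    out = softmax(scores) @ x_V @ W_O + t          # residual
    return out, x_K, x_V
```

A full generation loop would call `prefill` once on the prompt and then call `decode_step` once per generated token, threading the caches through.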
2 Additional GPU memory footprint
Suppose the batch size is $b$, the number of transformer layers is $l$, the hidden dimension is $h$, the input sequence length is $s$ and the output length is $n$. Storing the KV cache in float16 results in a peak memory footprint of $2 \times 2 \times blh(s+n) = 4blh(s+n)$ bytes. Here, the first 2 accounts for the KV cache storing both keys and values, and the second 2 denotes float16 occupying 2 bytes per element.
Take GPT-3 as an example, with $l = 96$ layers and hidden dimension $h = 12288$: the GPU memory footprint of the model parameters in float16 is about 350GB.
Suppose the batch size $b = 64$, the input sequence length $s = 512$ and the output length $n = 32$; then storing the KV cache takes $4 \times 64 \times 96 \times 12288 \times (512 + 32)\ \text{bytes} \approx 164\ \text{GB}$, which is approximately 0.5 times the memory footprint of the model parameters.
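The arithmetic, as a quick script (the GPT-3 shape constants are public; the batch and length settings are the ones assumed above):

```python
def kv_cache_bytes(b, l, h, s, n, bytes_per_elem=2):
    # 2 tensors (K and V) per layer, each of b * (s + n) * h elements
    return 2 * bytes_per_elem * b * l * h * (s + n)

gb = kv_cache_bytes(b=64, l=96, h=12288, s=512, n=32) / 1e9
print(f"{gb:.0f} GB")  # ~164 GB, about half of the 350 GB of weights
```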
3 Computation saved
For $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, calculating $AB$ needs $mp(2n-1)$ FLOPs: $mnp$ multiplications and $mp(n-1)$ additions. We approximate this by $2mnp$, since usually $n$ is large. (Another interpretation here is that if we consider the bias part, as in the linear layers of the MLP, it is exactly $2mnp$.) For ease of interpretation, we eliminate the batch size $b$ here (i.e., take $b = 1$).
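A quick check of how tight the $2mnp$ approximation is at a realistic size (the shapes below are just an example):

```python
def matmul_flops_exact(m, n, p):
    # m*p output entries, each needing n multiplications and n - 1 additions
    return m * p * (2 * n - 1)

def matmul_flops_approx(m, n, p):
    return 2 * m * n * p

print(matmul_flops_exact(1, 12288, 12288))   # 301977600
print(matmul_flops_approx(1, 12288, 12288))  # 301989888, within 0.005%
```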
We take generating one token as the example.
3.1 Computation overhead in self-attention
3.1.1 without KV cache
(the input dimension is $s \times h$)
- Compute $x_Q$, $x_K$, $x_V$: each is $(s \times h) \times (h \times h)$, cost $3 \times 2sh^2 = 6sh^2$
- $x_Q x_K^{\top}$: $(s \times h) \times (h \times s)$, cost $2s^2h$
- $\mathrm{softmax}(\cdot) \cdot x_V$: $(s \times s) \times (s \times h)$, cost $2s^2h$
- linear mapping after attention: $(s \times h) \times (h \times h)$, cost $2sh^2$
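Summing up: $6sh^2 + 2s^2h + 2s^2h + 2sh^2 = 8sh^2 + 4s^2h$ FLOPs per layer.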
3.1.2 with KV cache
(the input comes to $t \in \mathbb{R}^{1 \times h}$, and the past $s$ tokens are in the KV cache)
- Compute $t_Q$, $t_K$, $t_V$: each is $(1 \times h) \times (h \times h)$, cost $3 \times 2h^2 = 6h^2$
- $t_Q x_K^{\top}$: $(1 \times h) \times (h \times s)$, cost $2sh$
- $\mathrm{softmax}(\cdot) \cdot x_V$: $(1 \times s) \times (s \times h)$, cost $2sh$
- linear mapping after attention: $(1 \times h) \times (h \times h)$, cost $2h^2$
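Summing up: $6h^2 + 2sh + 2sh + 2h^2 = 8h^2 + 4sh$ FLOPs per layer; the quadratic dependence on $s$ is gone.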
3.2 Computation overhead in MLP
3.2.1 without KV cache
- the first linear layer: $(s \times h) \times (h \times 4h)$, cost $8sh^2$
- the second linear layer: $(s \times 4h) \times (4h \times h)$, cost $8sh^2$
3.2.2 with KV cache
- the first linear layer: $(1 \times h) \times (h \times 4h)$, cost $8h^2$
- the second linear layer: $(1 \times 4h) \times (4h \times h)$, cost $8h^2$
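In total, the MLP costs $16sh^2$ without KV cache versus $16h^2$ with KV cache.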
3.3 Computation overhead in logits (which maps the hidden states to the vocabulary size $V$)
- without KV cache: $(s \times h) \times (h \times V)$, cost $2shV$
- with KV cache: $(1 \times h) \times (h \times V)$, cost $2hV$
3.4 Summary
|  | without KV cache | with KV cache |
| --- | --- | --- |
| self-attention | $8sh^2 + 4s^2h$ | $8h^2 + 4sh$ |
| MLP | $16sh^2$ | $16h^2$ |
| logits | $2shV$ | $2hV$ |
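To get a feel for the savings, the snippet below plugs GPT-3-scale numbers into the table; $h = 12288$ and $l = 96$ are the public GPT-3 shapes, while $V = 50257$ and the context length $s = 512$ are assumptions for illustration (the self-attention and MLP costs apply per layer, the logits cost once per model):

```python
h, V, s, l = 12288, 50257, 512, 96

without_kv = l * (8*s*h**2 + 4*s**2*h + 16*s*h**2) + 2*s*h*V
with_kv    = l * (8*h**2   + 4*s*h    + 16*h**2)   + 2*h*V

print(f"FLOPs per new token, no cache:   {without_kv:.2e}")
print(f"FLOPs per new token, with cache: {with_kv:.2e}")
print(f"ratio: {without_kv / with_kv:.0f}x")  # equals s = 512 here
```

Since every without-cache entry in the table is exactly $s$ times its with-cache counterpart, the ratio works out to $s$ regardless of the model shape, which is why KV cache matters so much for long contexts.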