Serving LLMs in Production - The Optimization
Post-training Model Compression
Common post-training model compression methods include pruning, quantization, and distillation. These methods effectively reduce model parameters, increase inference speed, and largely maintain model accuracy.
Pruning
Pruning reduces a model’s effective size by setting redundant weights to 0. Structured pruning removes entire neurons or filters, while unstructured pruning zeroes individual weights. In favorable cases pruning can remove a large share of the parameters (reported reductions exceed 90%) while largely preserving accuracy, which makes it a cost-effective compression method for large models.
A simple and intuitive pruning method involves setting a threshold; if a weight’s absolute value is below this threshold, it is set to 0. Alternatively, a binary mask of the same size and shape as the weights can represent which parameters need pruning. By multiplying the mask with the weights, we obtain the pruned parameters.
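The threshold-and-mask idea can be sketched in a few lines of plain Python. This is a minimal illustration on nested lists, not a production pruning routine; the function names and the example threshold of 0.1 are assumptions made for the sketch.

```python
def magnitude_mask(weights, threshold):
    """Return a 0/1 mask: 1 keeps a weight, 0 prunes it (|w| below threshold)."""
    return [[1 if abs(w) >= threshold else 0 for w in row] for row in weights]


def apply_mask(weights, mask):
    """Multiply weights by the mask element-wise to get the pruned weights."""
    return [[w * m for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(weights, mask)]


def sparsity(mask):
    """Fraction of weights that were pruned."""
    total = sum(len(row) for row in mask)
    zeros = sum(row.count(0) for row in mask)
    return zeros / total


weights = [[0.8, -0.05, 0.3], [0.01, -0.6, 0.02]]
mask = magnitude_mask(weights, threshold=0.1)
pruned = apply_mask(weights, mask)
print(mask, pruned, sparsity(mask))
```

Keeping the mask separate from the weights makes it possible to reapply the same sparsity pattern after any later fine-tuning step.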
Quantization
Model quantization is a post-training model optimization technique that uses lower-precision data types (like int8) for model parameters and activations, instead of the 32-bit floating-point format (float32) typically used in training, to reduce computation and memory requirements. Integer data types also speed up matrix multiplications, further enhancing inference speed. There are various quantization methods, such as weight quantization, activation quantization, and gradient quantization.
The two most common quantization types are:
float32 -> float16
float32 -> int8
Quantizing from float32 to float16 is straightforward; it merely involves converting model parameters and activation data types from float32 to float16. However, float32 to int8 quantization is more complex, as int8 can only represent integers between -128 and 127, while float32 has a much larger range. Therefore, an affine quantization scheme is often used for float32 to int8 quantization.
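For a specific float32 parameter range \([\alpha, \beta]\), the mapping between float32 and int8 values is given by the following formula:
\[x = S \cdot (x_q - Z)\]
where \(x\) is the original float32 parameter, \(x_q\) is the quantized int8 parameter, \(S\) is the scaling factor, and \(Z\) is the zero point, representing the int8 value corresponding to 0 in float32. Quantized values \(x_q\) are calculated as follows:
\[x_q = \text{round}\left(\frac{x}{S} + Z\right)\]
Values outside \([\alpha, \beta]\) in float32 will be clipped to the nearest representable value. Therefore, any floating-point number \(x\) can be quantized to int8 using:
\[x_q = \text{clip}\left(\text{round}\left(\frac{x}{S} + Z\right), \text{round}\left(\frac{\alpha}{S} + Z\right), \text{round}\left(\frac{\beta}{S} + Z\right)\right)\]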
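The affine scheme described above can be sketched as follows. The way \(S\) and \(Z\) are derived from \([\alpha, \beta]\) is a common convention and an assumption here (the text does not specify it), and the clipping bounds are written as the int8 limits, which is what \(\alpha\) and \(\beta\) map to under that convention.

```python
def affine_params(alpha, beta, qmin=-128, qmax=127):
    """Scale S and zero point Z that map [alpha, beta] onto [qmin, qmax] (assumed convention)."""
    scale = (beta - alpha) / (qmax - qmin)
    zero_point = round(qmin - alpha / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """x_q = clip(round(x / S + Z), qmin, qmax)."""
    q = round(x / scale + zero_point)
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """x is approximately S * (x_q - Z)."""
    return scale * (q - zero_point)


scale, zero_point = affine_params(0.0, 2.55)
for value in (-1.0, 0.0, 1.0, 2.55, 10.0):
    q = quantize(value, scale, zero_point)
    print(value, q, dequantize(q, scale, zero_point))
```

Values outside the range (here -1.0 and 10.0) saturate at the int8 limits, so the round trip through `dequantize` only recovers in-range values, and only up to a rounding error of about one scale step.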
Quantization methods have many additional details and specific optimization techniques, but these methods are generally effective for models of all sizes.
Distillation
Model distillation trains a smaller model (student model) to approximate a larger model (teacher model), effectively transferring knowledge from the larger model to a model with far fewer parameters. For LLMs, this approach is particularly useful since they contain billions of parameters and are inherently “capable” of various tasks.
One distillation method uses LLMs for labeling unlabeled data on a specific task, which can then train a smaller task-specific model. This method is straightforward but requires substantial unlabeled data and training time. Notably, many GenAI companies (like OpenAI) explicitly prohibit using their models’ generated data to train competitors’ models.
Another popular distillation method is Knowledge Distillation (e.g., DistilBERT), where a smaller model (student) is trained to approximate the output probability distribution of a larger model (teacher). The student learns from the teacher’s full distribution over outputs (soft labels) rather than only the single most likely output. Since this requires access to the LLM’s full output probability distribution, it is generally limited to models whose logits are available, such as open-source LLMs.
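A minimal sketch of the loss behind this idea is shown below. It assumes the common formulation with a softmax temperature and a KL divergence between teacher and student distributions, scaled by the squared temperature; the function names and the temperature value are illustrative, not taken from a specific library.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2  # scaling assumed from the usual formulation


teacher = [4.0, 1.0, 0.5]
print(distillation_loss(teacher, [4.0, 1.0, 0.5]))  # identical logits
print(distillation_loss(teacher, [0.5, 1.0, 4.0]))  # student disagrees
```

A higher temperature flattens both distributions, so the student also learns how the teacher ranks the less likely outputs.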
Decoding
Decoding speed directly affects LLM inference speed, so several decoding optimization methods exist for LLMs. Traditional LLMs perform inference using autoregressive sampling (as shown below).
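This structure introduces a sequential dependency in model inference: each token can be generated only after all preceding tokens, so decoding steps cannot run in parallel.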
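The dependency is easiest to see in a toy decoding loop. In the sketch below, `next_token` is a hypothetical stand-in for one forward pass of a real model; only the loop structure matters.

```python
def next_token(tokens):
    """Hypothetical stand-in for one LLM forward pass with greedy selection."""
    return (sum(tokens) + 1) % 7


def generate(prompt, max_new_tokens, eos=0):
    """Autoregressive decoding: every step consumes all tokens produced so far."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tok = next_token(tokens)  # cannot start before the previous step finished
        tokens.append(tok)
        if tok == eos:
            break
    return tokens


print(generate([1, 2], max_new_tokens=3))
```

Generating `n` tokens therefore costs `n` sequential model calls, which is the cost the methods in this section try to reduce.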
Non-autoregressive (NAR) Decoding
The term “Autoregressive” refers to the use of previous time steps to predict the current output. Non-autoregressive decoding generates all outputs simultaneously without depending on prior outputs, effectively removing dependencies between tokens and allowing parallel generation. This method significantly reduces decoding time but may decrease model accuracy.
Speculative Decoding
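Speculative Decoding, also known as assisted generation, speeds up LLM inference (roughly 2-3x) with minimal accuracy loss. Medusa is a related technique that replaces the separate draft model with extra decoding heads on the target model. The basic form requires one or more smaller models to assist LLM inference: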
- Target Model: The primary LLM we use
- Small Draft (Approximation) Model: A lighter model primarily for speeding up inference
During inference, the Small Draft Model generates several draft tokens, which the Target Model then verifies in a single forward pass, accepting the tokens it agrees with and correcting the first one it rejects. This speeds up inference because one target-model pass can yield multiple tokens. However, it requires additional computational resources for the smaller model.
To analyze Speculative Decoding’s performance, consider the worst case, where every draft is wrong and the target model must correct the first draft token each time. Each target pass still yields one token, so the number of target-model passes does not exceed that of autoregressive decoding. However, if the draft model generates satisfactory tokens, multiple draft tokens can be validated in a single target model pass, significantly increasing speed.
If an additional draft model isn’t available, the Speculative Decoding concept can still be applied by reusing previously generated tokens in the sequence as speculation for new tokens, known as N-gram assisted generation.
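One way to picture N-gram assisted generation is a lookup over the tokens generated so far. The sketch below is an assumption about how such a lookup might work, not the implementation of any particular library: it finds the most recent earlier occurrence of the trailing n-gram and proposes the tokens that followed it as the draft.

```python
def ngram_draft(tokens, ngram_size=2, num_draft=3):
    """Propose draft tokens from an earlier occurrence of the trailing n-gram."""
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Scan backwards over earlier positions, skipping the trailing n-gram itself.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == tail:
            follow = start + ngram_size
            return tokens[follow:follow + num_draft]
    return []  # no match: fall back to ordinary decoding


print(ngram_draft([5, 6, 7, 8, 9, 5, 6]))
```

The proposed tokens would then be verified by the target model exactly like the output of a draft model.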
Batching
Batching is a vital optimization method in model inference. By processing multiple input data instances simultaneously, batching maximizes parallel computation, enhancing GPU utilization and increasing inference throughput.
Static Batching
Static batching groups a fixed set of requests into one batch that is processed together from start to finish. This amortizes the cost of loading model weights across requests and increases throughput, but requires all inputs in the batch to share the same shape. For inputs of different sizes, padding may be necessary.
A simple example is shown below:
Static batching waits for the “longest” request to finish before returning results. This can lead to extended waiting times for some requests, limiting GPU utilization. Moreover, with a small workload, static batching might cause queuing delays as it waits for enough requests to reach the batch size, further reducing GPU utilization.
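The padding requirement and the cost of waiting for the longest request can both be sketched in plain Python. The helper names, the padding id of 0, and the definition of utilization (useful decode steps divided by occupied slot-steps) are assumptions made for this illustration.

```python
def pad_batch(requests, pad_id=0):
    """Pad token lists to a common length and return an attention mask."""
    max_len = max(len(r) for r in requests)
    padded = [r + [pad_id] * (max_len - len(r)) for r in requests]
    mask = [[1] * len(r) + [0] * (max_len - len(r)) for r in requests]
    return padded, mask


def static_batch_cost(output_lengths):
    """Every request holds its slot until the longest one finishes."""
    steps = max(output_lengths)
    useful = sum(output_lengths)
    utilization = useful / (steps * len(output_lengths))
    return steps, utilization


print(pad_batch([[1, 2, 3], [4], [5, 6]]))
print(static_batch_cost([2, 8, 4, 2]))
```

In the second call, three of the four requests finish early but keep occupying their slots, so half of the slot-steps do no useful work.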
Continuous Batching
To address the waiting issue in static batching, continuous batching returns completed requests first, continuing inference for other requests. A visual example is below:
However, a potential issue with this approach is the need to pre-allocate “sufficient” memory space to handle the longest request in each batch to prevent memory overflow. This can lead to memory waste, as the size of incoming requests is unknown in advance. This is why PagedAttention is used to optimize LLM memory management, which we’ll discuss later.
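A small simulation makes the scheduling difference concrete. It counts decode steps only and ignores prefill and memory limits, both simplifying assumptions; the function names are illustrative.

```python
from collections import deque


def continuous_batching_steps(output_lengths, batch_size):
    """Decode steps needed when finished requests free their slot immediately."""
    waiting = deque(output_lengths)
    running = []  # remaining tokens for each active request
    steps = 0
    while waiting or running:
        # Fill free slots from the queue before every step.
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        steps += 1
        # One step yields one token per active request; drop finished ones.
        running = [r - 1 for r in running if r - 1 > 0]
    return steps


def static_batching_steps(output_lengths, batch_size):
    """Decode steps needed when each batch runs until its longest request ends."""
    return sum(max(output_lengths[i:i + batch_size])
               for i in range(0, len(output_lengths), batch_size))


lengths = [2, 8, 4, 2, 3, 3]
print(static_batching_steps(lengths, batch_size=2))
print(continuous_batching_steps(lengths, batch_size=2))
```

With these example lengths the continuous schedule keeps both slots busy on every step, while the static schedule spends steps waiting for the longest request in each batch.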
Memory Optimization
The attention mechanism is central to LLMs, enabling them to weigh input data parts differently and focus on relevant information for a task. Improving attention performance enhances inference speed while reducing computation. Since attention is memory-bound, memory management optimization is crucial.
LLM inference memory allocation breakdown:
- Model Parameters: About 65%-70%
- KV Cache: Stores the key and value tensors of already-processed tokens so they are not recomputed at each decoding step, about 25%-30%
- Others: Activations, intermediate results, about 5%-10%
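The size of the first two items can be estimated with simple arithmetic. The sketch below assumes a standard multi-head attention layout and 2-byte (float16) values; the layer and head counts in the example are hypothetical and only roughly resemble a 13B-parameter model.

```python
def weight_bytes(num_params, bytes_per_value=2):
    """Memory for model parameters."""
    return num_params * bytes_per_value


def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size,
                   bytes_per_value=2):
    """Memory for the KV cache; the factor 2 covers one key and one value per token."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_value)


per_token = kv_cache_bytes(40, 40, 128, seq_len=1, batch_size=1)
full = kv_cache_bytes(40, 40, 128, seq_len=2048, batch_size=1)
print(per_token, full / 2**30, weight_bytes(13_000_000_000) / 2**30)
```

The KV cache grows linearly with both sequence length and batch size, which is why it becomes the main memory pressure under heavy load even though the weights are larger for a single request.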
A simplified version of LLM memory usage during inference:
LLM inference consists of two phases:
- Prefill phase: Prompt token preloading and input computation
- Decode phase: Generation of each subsequent token
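During the prefill phase, the LLM processes all prompt tokens in parallel and caches their keys and values, which remain constant in subsequent generation steps; this phase makes good use of GPU parallelism. The decode phase then generates one token at a time, reusing the cache.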
FlashAttention (GPU only)
Regardless of memory type, a general rule is that faster memory is more expensive and has less capacity. FlashAttention uses three types of memory (a sample configuration on an Nvidia A100 + CPU machine):
- CPU main memory (CPU DRAM): Largest capacity, slowest speed (12.8GB/s, >1TB)
- HBM (high bandwidth memory): Faster speed, moderate capacity (1.5TB/s, 40GB)
- SRAM (on-chip memory): Fastest, smallest capacity (19TB/s, 20MB)
“Graphics memory” typically refers to HBM.
Traditional Transformer inference follows these steps, assuming \(K, V, Q \in R^{N \times d}\) matrices in HBM:
- Load \(Q\) and \(K\) by blocks from HBM, compute \(S = QK^T\) and store \(S\) in HBM.
- Read \(S\) from HBM, compute \(P = \text{softmax}(S)\), and store \(P\) in HBM.
- Load \(P\) from HBM, compute \(O = PV\), and store \(O\) in HBM.
- Return \(O\)
Here, \(K, V, Q\) are key, value, and query matrices, \(S\) is the score matrix, \(P\) is the softmax probability matrix, and \(O\) is the output matrix. The standard implementation writes the \(N \times N\) matrices \(S\) and \(P\) to HBM and reads them back, which increases HBM read/write operations and reduces inference speed.
The intermediate storage of \(S\) between the first and second steps is unnecessary. If \(S\) is cached in SRAM, HBM read/write operations can be reduced. This approach, known as kernel fusion on GPUs, merges low-level computations to reduce communication overhead. However, SRAM’s limited capacity prevents it from storing the entire \(S\) matrix. FlashAttention addresses this issue with two key ideas:
- Tiling: Used in both forward and backward passes by chunking the NxN softmax/scores matrix into blocks, each small enough for SRAM.
- Recomputation: Used only in the backward pass. This approach stores outputs and softmax normalization statistics to recompute attention matrices, using incremental calculation to obtain precise attention outputs.
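The numerical core of tiling is a softmax that can be computed block by block while keeping only a running maximum, a running sum, and an output accumulator. The sketch below shows this for a single query vector in plain Python. It illustrates the arithmetic only and makes no attempt to model SRAM, HBM, or a GPU kernel; the function names are illustrative.

```python
import math


def attention_reference(q, keys, values):
    """Standard attention for one query: materializes all scores at once."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, values)) / total
            for j in range(len(values[0]))]


def attention_tiled(q, keys, values, block_size=2):
    """Process K/V in blocks, rescaling running statistics as the max changes."""
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])  # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescale earlier blocks
        exps = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(exps)
        acc = [a * correction + sum(e * v[j] for e, v in zip(exps, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.5, -0.3], [0.0, 2.0], [-1.0, 0.4], [0.7, 0.7]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
print(attention_reference(q, keys, values))
print(attention_tiled(q, keys, values))
```

Both functions produce the same output up to floating-point rounding, which is why tiling yields exact rather than approximate attention.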
PagedAttention
According to our earlier analysis of LLM memory structures, the KV Cache occupies significant memory, making it a primary target for optimization.
PagedAttention, a vLLM optimization inspired by the “virtual memory with paging” concept in traditional OS memory management, divides the KV cache into fixed-size pages (blocks) that do not need to be contiguous in memory. During inference, pages are allocated on demand as a sequence grows.
This approach eliminates the need to pre-allocate contiguous memory space for the maximum context length and allocates KV cache pages dynamically during inference to improve memory efficiency. However, this method does not guarantee perfect memory utilization; the last block of each sequence may be only partially filled (vLLM reports less than 4% waste).
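The bookkeeping behind paging can be sketched with a toy block allocator. This is not vLLM's implementation; the class and method names are hypothetical, and integer block ids stand in for physical memory pages.

```python
class PagedKVCache:
    """Toy allocator that maps each request to a list of fixed-size blocks."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # request id -> list of physical block ids
        self.lengths = {}       # request id -> number of cached tokens

    def append_token(self, request_id):
        """Reserve a slot for one more token, allocating a block only when needed."""
        table = self.block_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length == len(table) * self.block_size:  # last block full, or none yet
            if not self.free_blocks:
                raise MemoryError("no free KV cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[request_id] = length + 1

    def free(self, request_id):
        """Return all blocks of a finished request to the pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

    def wasted_slots(self, request_id):
        """Unused slots in the last block of a request."""
        used = self.lengths[request_id]
        return len(self.block_tables[request_id]) * self.block_size - used


cache = PagedKVCache(num_blocks=4, block_size=4)
for _ in range(5):
    cache.append_token("a")
cache.append_token("b")
print(cache.block_tables, cache.wasted_slots("a"), len(cache.free_blocks))
```

Waste is bounded by one partially filled block per sequence, instead of the difference between the reserved maximum length and the actual length.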
Parallelism
Parallel computing is a crucial method for speeding up model inference by dividing the computation into multiple parts for simultaneous execution. Common parallel computing methods include data parallelism, model parallelism, tensor parallelism, and pipeline parallelism. While parallelism is widely used in both model training and inference, Tensor and Pipeline parallelism are more common in inference due to the smaller computation load, more complex data pipeline, and need to handle production uncertainties.
Tensor Parallelism
Tensor parallelism uses multiple GPUs for distributed matrix computation. In LLM inference, tensor operations mainly refer to matrix multiplications of the Key, Value, and Query matrices in the attention mechanism and MLP layers. Matrix multiplications are split across GPUs to utilize multiple devices.
For example, with input \(X\), weight matrix \(A\), and output \(Y\), we can split the calculation across 3 GPUs by partitioning \(A\) column-wise:
\[Y = X \cdot A\]
is equivalent to
\[\begin{bmatrix} Y_1 & Y_2 & Y_3 \end{bmatrix} = X \cdot \begin{bmatrix} A_1 & A_2 & A_3 \end{bmatrix}\]
where each \(Y_i = X \cdot A_i\) is computed on a separate GPU and the results are concatenated.
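The column-wise split can be checked with a few lines of plain Python. The sketch runs the shards one after another on a single machine; in a real deployment each shard would live on its own GPU and the final concatenation would be a collective communication step. The function names are illustrative.

```python
def matmul(x, a):
    """x is (m x k), a is (k x n), both as lists of rows."""
    return [[sum(x_row[t] * a[t][j] for t in range(len(a)))
             for j in range(len(a[0]))] for x_row in x]


def split_columns(a, parts):
    """Split a into equal column blocks (assumes the width is divisible by parts)."""
    step = len(a[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in a] for i in range(parts)]


def tensor_parallel_matmul(x, a, parts):
    shards = split_columns(a, parts)
    partials = [matmul(x, shard) for shard in shards]  # one shard per device
    # Gather: concatenate the partial outputs along the column axis.
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


x = [[1, 2]]
a = [[1, 2, 3], [4, 5, 6]]
print(matmul(x, a), tensor_parallel_matmul(x, a, parts=3))
```

Each device only needs its own column block of \(A\), so the weight memory per device drops roughly in proportion to the number of shards.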
Pipeline Parallelism
Pipeline parallelism divides the model by layers into sequential chunks, each allocated to a different device (stage). During the forward pass, each stage passes its intermediate activations to the next stage, so different stages can work on different inputs at the same time.
Because activations must flow through the stages in order, some devices sit idle while waiting for the previous stage’s results. These idle periods, known as “pipeline bubbles,” reduce the efficiency of pipeline parallelism.
Microbatching can reduce pipeline bubbles by dividing input data into smaller batches and interleaving them across stages, reducing waiting times. However, this does not entirely eliminate pipeline bubbles.
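The effect of microbatching can be quantified with a simple model. Assuming every stage takes the same time per microbatch (an idealization), a pipeline with `p` stages and `m` microbatches needs `m + p - 1` time steps while each stage is busy for only `m` of them, which gives the idle fraction computed below.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Idle fraction of a simple forward pipeline with equal-cost stages."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for m in (1, 4, 12):
    print(m, bubble_fraction(num_stages=4, num_microbatches=m))
```

The fraction shrinks as the number of microbatches grows but stays above zero for any finite number, which matches the observation that bubbles cannot be eliminated entirely.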
Reference
Some of the pictures/links are from the following articles:
- Fast Inference from Transformers via Speculative Decoding
- Model Compression via Pruning
- LLM distillation demystified: a complete guide
- LLM Distillation Explained: Applications, Implementation & More
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- vLLM: A high-throughput and memory-efficient inference and serving engine for LLMs
- What it means to serve an LLM and which serving technology to choose from
- Mastering LLM Techniques: Inference Optimization
- GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
- Towards Efficient Generative Large Language Model Serving: A Survey from Algorithms to Systems