LLM Inference Optimization#
One of the perks of working at ThoughtWorks is being surrounded by very smart people who constantly keep you on your toes—and never let you stop learning. One such colleague recently triggered some thoughts about inference optimizations, which led me to explore and write a few things down to connect the dots.
BTW, the image above is just a random output from a diffusion model and has absolutely nothing to do with inference optimization. I added it just to reflect the state of my mind—with so many AI concepts floating around—everything starts to make sense… and then in the next moment, it doesn’t. 😉
LLM Inference#
LLM Inference is the end-to-end flow of executing a forward pass through a pre-trained large language model (LLM) to generate an output from a given input sequence. This involves processing the input into tokens and applying the model's fixed parameters (learned during training) to transform input tokens into contextually relevant output tokens. Now this is a way too simplified definition, so let's continue to double-click…
The end-to-end pipeline of LLM inference can be conceptualized as Input → Tokenization → Embedding → Prefill → Transformer Decoding / Generation → Output.
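To make this pipeline concrete before digging into the math, here is a minimal sketch of the end-to-end flow using the Hugging Face `transformers` library. The checkpoint name `meta-llama/Llama-2-7b-hf`, the fp16 loading, and the generation settings are my assumptions for illustration; access to the weights requires Meta's license approval.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint (gated access)

# Tokenization: text -> token IDs
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Why there is always so much to learn?", return_tensors="pt")

# Load the pre-trained weights (fp16 keeps the 7B model around ~14 GB)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Prefill + decoding: forward passes that generate output tokens one at a time
with torch.no_grad():
    output_ids = model.generate(**inputs.to(model.device), max_new_tokens=32)

# Detokenization: token IDs -> text
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```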
It would be too much to get into GPU architecture here, but just for my own curiosity, here is a brief note on where on the GPU the caching happens and where the compute happens.
GDDR (Graphics Double Data Rate) / VRAM memory is used as the global memory for the GPU to store model weights, activations, and KV caches for LLM inference.
SMs (Streaming Multiprocessors) are the compute units that execute matrix multiplications, attention, softmax, etc. Each SM contains specialized units for fast matrix multiplies (used heavily in transformer layers) called Tensor Cores.
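To appreciate why that global memory matters, here is a rough back-of-envelope sketch (my own assumptions: fp16 storage, a 4096-token context, and ignoring activations and framework overhead) of what LLaMA 2 7B puts into VRAM:

```python
# Back-of-envelope VRAM estimate for LLaMA 2 7B inference in fp16
BYTES_FP16 = 2

n_params = 7e9                                   # ~7 billion weights
weights_gb = n_params * BYTES_FP16 / 1e9
print(f"Model weights: ~{weights_gb:.0f} GB")    # ~14 GB

# KV cache per sequence: 2 (K and V) x layers x heads x head_dim x seq_len x bytes
n_layers, n_heads, head_dim = 32, 32, 128        # LLaMA 2 7B dimensions
seq_len = 4096                                   # max context length
kv_cache_gb = 2 * n_layers * n_heads * head_dim * seq_len * BYTES_FP16 / 1e9
print(f"KV cache at {seq_len} tokens: ~{kv_cache_gb:.1f} GB per sequence")  # ~2.1 GB
```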
Appreciating the Memory and Computation#
Let's appreciate how much computation and memory are required to run inference on a large language model (LLM) like LLaMA 2 7B.
We will also assume that each word in the input sequence is a single token, i.e. the prompt `Why there is always so much to learn?` is tokenized into 8 tokens with IDs `[2993, 1011, 318, 6717, 1121, 632, 311, 13779]`.
for 🦙 LLaMA 2 7B:#
\(N = 8\): Number of input tokens from `Why there is always so much to learn?`
\(d_{model}\): Hidden size = 4096, i.e. the dimensionality of the embeddings used throughout the model. Every token is represented as a vector of 4096 dimensions inside the model, and all linear layers (Q/K/V projections, MLPs) operate on this dimensionality.
\(h\): Number of attention heads = 32
\(d_q = d_k = d_v = 128\): Dimension of each Query, Key, and Value vector per head (4096 / 32)
This means each of the 32 attention heads operates on 128-dimensional query, key, and value vectors. All heads run in parallel, and their outputs are concatenated to form a vector of size 4096 again (32 × 128 = 4096).
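These dimensions are not something we have to memorize; they can be read straight off the model config. A small sketch (same assumed checkpoint name as before):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

print(config.hidden_size)                                  # 4096 -> d_model
print(config.num_attention_heads)                          # 32   -> h
print(config.hidden_size // config.num_attention_heads)    # 128  -> d_q = d_k = d_v
print(config.num_hidden_layers)                            # 32 transformer layers
print(config.vocab_size)                                   # 32000
```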
1. Token Embedding Lookup
Each token ID is mapped to a 4096-d vector via an embedding table of shape \([V, 4096]\), with \(V\) = vocab size (32,000 for LLaMA 2 7B).
Shape:
X = token_embedding[input_ids] # Shape: [8, 4096]
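A sketch of that lookup in PyTorch, using a randomly initialized embedding table in place of the pre-trained one and the illustrative token IDs from above:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 32000, 4096
token_embedding = nn.Embedding(vocab_size, d_model)    # the [V, 4096] lookup table

input_ids = torch.tensor([2993, 1011, 318, 6717, 1121, 632, 311, 13779])  # 8 tokens
X = token_embedding(input_ids)                         # each ID -> its 4096-d row
print(X.shape)                                         # torch.Size([8, 4096])
```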
2. LayerNorm - Pre-Attention
LayerNorm is typically performed before the attention or MLP block in the transformer architecture. Why?
It stabilizes training (i.e. ensures consistent scaling) and helps avoid exploding/vanishing gradients.
It improves convergence, i.e. allows better optimization with gradient descent.
LayerNorm 1 i.e. Pre-Attention: \(\text{LN}_1(X) \in \mathbb{R}^{8 \times 4096}\) is given by:

\[ \text{LN}_1(X)_{ij} = \gamma_j \cdot \frac{X_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_j \]

Or more compactly for the full matrix:

\[ \text{LN}_1(X) = \gamma \odot \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]

with:
\(X \in \mathbb{R}^{8 \times 4096}\) — the input matrix.
\(X_i \in \mathbb{R}^{4096}\) — the i-th row of \(X\), i.e., one token embedding.
\(\mu_i = \frac{1}{4096} \sum_{j=1}^{4096} X_{ij}\) — mean of the features for token \(i\).
\(\sigma_i^2 = \frac{1}{4096} \sum_{j=1}^{4096} (X_{ij} - \mu_i)^2\) — variance of the features.
\(\gamma \in \mathbb{R}^{4096}\): the model's learned scaling parameter; it learns how much to scale each feature.
\(\beta \in \mathbb{R}^{4096}\): the model's learned shifting parameter; it learns how much to shift each feature.
\(\epsilon\): a small constant for numerical stability (e.g., \(10^{-6}\)) that avoids division by zero.
Shape: the output shape remains the same, i.e. \(X_{\text{norm}} = \text{LN}_1(X) \in \mathbb{R}^{8 \times 4096} = [8, 4096]\)
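Here is a minimal sketch of the formula above, checked against PyTorch's built-in LayerNorm. (Strictly speaking, LLaMA 2 uses RMSNorm, a simplified variant that skips the mean-centering and \(\beta\); the code below implements the standard LayerNorm the equations describe.)

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(8, 4096)                 # stand-in for the token embeddings
gamma = torch.ones(4096)                 # learned scale (initialized to 1)
beta = torch.zeros(4096)                 # learned shift (initialized to 0)
eps = 1e-6

# Normalize each token (row) across its 4096 features
mu = X.mean(dim=-1, keepdim=True)                   # [8, 1]
var = X.var(dim=-1, unbiased=False, keepdim=True)   # [8, 1]
X_norm = gamma * (X - mu) / torch.sqrt(var + eps) + beta

# Same result as PyTorch's built-in implementation
ref = F.layer_norm(X, (4096,), gamma, beta, eps)
print(torch.allclose(X_norm, ref, atol=1e-5))       # True
print(X_norm.shape)                                 # torch.Size([8, 4096])
```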
3. Linear Projections for Q/K/V vectors
In the attention mechanism, each input token is projected into three different vector spaces:
Q = Query i.e. What this token is looking for
K = Key i.e. What other tokens contain
V = Value i.e. The actual content/information. Together, these determine how much attention each token should pay to the others.
for 🦙 LLaMA 2 7B the pre-trained \(Q_{weights}\), \(K_{weights}\), \(V_{weights}\) would be#
\(Q_{weights}\) = \(W_Q \in \mathbb{R}^{4096 \times 4096}\)
Learned matrix to project each token’s embedding into a query vector.
Tells the model what the token is querying for in the context.
\(K_{weights}\) = \(W_K \in \mathbb{R}^{4096 \times 4096}\)
Learned matrix to project the token into a key vector.
Encodes what each token offers or matches when compared with queries.
\(V_{weights}\) = \(W_V \in \mathbb{R}^{4096 \times 4096}\)
Learned matrix to produce the value vector from each token.
Holds the actual content/information to be passed along after attention is calculated.
So… let's compute \(Q_{\text{vector}}\), \(K_{\text{vector}}\), \(V_{\text{vector}}\)#

\[ Q_{\text{vector}} = X_{\text{norm}} W_Q, \quad K_{\text{vector}} = X_{\text{norm}} W_K, \quad V_{\text{vector}} = X_{\text{norm}} W_V \]

Where:
\(X_{\text{norm}} = \text{LN}_1(X) \in \mathbb{R}^{8 \times 4096}\)
\(W_Q, W_K, W_V \in \mathbb{R}^{4096 \times 4096}\) — learned projection weight matrices.
Shape: Output shape \(Q_{\text{vector}}, K_{\text{vector}}, V_{\text{vector}} \in \mathbb{R}^{8 \times 4096} = [8, 4096]\)
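In code, the three projections are just matrix multiplications; random matrices stand in for the pre-trained \(W_Q, W_K, W_V\) in this sketch:

```python
import torch

torch.manual_seed(0)
d_model = 4096
X_norm = torch.randn(8, d_model)   # stand-in for LN_1(X)

# Random stand-ins for the pre-trained projection matrices
W_Q = torch.randn(d_model, d_model)
W_K = torch.randn(d_model, d_model)
W_V = torch.randn(d_model, d_model)

Q = X_norm @ W_Q   # [8, 4096] what each token is looking for
K = X_norm @ W_K   # [8, 4096] what each token offers
V = X_norm @ W_V   # [8, 4096] the content each token carries
print(Q.shape, K.shape, V.shape)
```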
🦙 LLaMA 2 7B operates on 32 heads in each transformer layer. Each of the 32 attention heads operates on 128-dimensional query, key, and value vectors, i.e. the dimensionality 4096 is divided across the heads. All heads run in parallel, and their outputs are concatenated to form a vector of size 4096 again (32 × 128 = 4096).
Head Splitting (reshape)#
Split the 4096-dimensional vectors into 32 heads of size 128 each: \([8, 4096] \rightarrow [8, 32, 128]\)
Transpose for Multi-Head Attention#
Transpose the axes to move heads to the first dimension: \([8, 32, 128] \rightarrow [32, 8, 128]\)
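A sketch of the reshape and transpose (shown for \(Q\); the same applies to \(K\) and \(V\)):

```python
import torch

n_tokens, n_heads, d_head = 8, 32, 128
Q = torch.randn(n_tokens, n_heads * d_head)   # [8, 4096] from the projection step

# Split the 4096 features into 32 heads of 128 dims each
Q_heads = Q.view(n_tokens, n_heads, d_head)   # [8, 32, 128]

# Move the head axis to the front so every head is an independent [8, 128] slice
Q_heads = Q_heads.permute(1, 0, 2)            # [32, 8, 128]
print(Q_heads.shape)
```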
4. Multi-Head Attention
⚫️ 4a. Attention Score Matrix#
The attention score matrix measures how much each token should pay attention to every other token in the sequence:

\[ \text{Scores}_{\text{head}} = \frac{Q_{\text{head}} \, K_{\text{head}}^{\top}}{\sqrt{d_k}} \]

Where:
\(Q_{\text{head}}, K_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128}\)
For each head: compute dot product between queries and keys for all tokens
\(d_k = 128\) is the head dimension — we scale by \(\sqrt{d_k}\) to stabilize gradients
For each head and each token:
It compares the query vector of that token with the key vectors of all tokens in the sequence
Produces a similarity score — high score means: “I’m interested in this token”
i.e. For each head (32 total), compute dot products between all pairs of tokens (8 queries × 8 keys).
Shape: \(\text{Scores}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 8}\) (for each head, an 8 × 8 matrix of token-to-token scores)
⚫️ 4b. Apply Softmax to Attention Score Matrix#
Once we have raw attention scores, we apply softmax across the last dimension. This is required because (a) raw scores can be negative, unbounded, or skewed, and (b) softmax turns them into probabilities, i.e. values between 0 and 1 that sum to 1 across the sequence:

\[ \text{Weights}_{\text{head}} = \text{softmax}\!\left(\text{Scores}_{\text{head}}\right) \in \mathbb{R}^{32 \times 8 \times 8} \]

This ensures that for each query token, the weights across all keys sum to 1.
⚫️ 4c. Compute Attention Output (Weighted Values)#
This step is the heart of the attention mechanism. It determines what each token should actually “see” or focus on, by combining information from other tokens based on attention scores:

\[ \text{Output}_{\text{head}} = \text{Weights}_{\text{head}} \cdot V_{\text{head}} \]

Shape:
\(\text{Weights}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 8}\)
\(V_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128}\)
The batched matrix product over the second dimension (sequence length) gives: \(\text{Output}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128}\)
💡 The attention for a head \(h\) can be written as (combining 4a, 4b, 4c):

\[ \text{head}_h = \text{Attention}(Q_h, K_h, V_h) = \text{softmax}\!\left(\frac{Q_h K_h^{\top}}{\sqrt{d_k}}\right) V_h \]
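Putting 4a, 4b and 4c together in code, batched over all 32 heads at once (random tensors stand in for the real per-head Q/K/V, and the causal mask used during autoregressive generation is omitted to keep the sketch short):

```python
import math
import torch

n_heads, n_tokens, d_head = 32, 8, 128
Q_h = torch.randn(n_heads, n_tokens, d_head)
K_h = torch.randn(n_heads, n_tokens, d_head)
V_h = torch.randn(n_heads, n_tokens, d_head)

# 4a. Scaled dot-product scores: [32, 8, 128] @ [32, 128, 8] -> [32, 8, 8]
scores = Q_h @ K_h.transpose(-2, -1) / math.sqrt(d_head)

# 4b. Softmax over the key dimension so each query's weights sum to 1
weights = torch.softmax(scores, dim=-1)        # [32, 8, 8]
print(weights.sum(dim=-1)[0, 0])               # tensor(1.)

# 4c. Weighted sum of values: [32, 8, 8] @ [32, 8, 128] -> [32, 8, 128]
out_heads = weights @ V_h
print(out_heads.shape)                         # torch.Size([32, 8, 128])
```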
⚫️ 4d. Lastly, Combine Heads (Concatenate)#
Remember that each head captures different features of the input. For LLaMA 2 7B, it's like a team of 32 specialists:
One looks at grammar
One looks at context
One tracks word position
and so on…
We will need to merge the heads back into one tensor per token because (a) the final shape must match the model's hidden dimension, and (b) it enables the output projection layer to linearly mix the insights from all heads.
💡 The full multi-head attention can be written as (combining 4a, 4b, 4c and 4d):

\[ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_{32}) \, W_O \]

where each \(\text{head}_h = \text{Attention}(Q_h, K_h, V_h)\) for \(h = 1, \ldots, 32\), and \(W_O \in \mathbb{R}^{4096 \times 4096}\) is the learned output projection matrix.
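And 4d in code: move the token axis back to the front, flatten the heads into one 4096-wide vector per token, and mix them with an output projection (a random stand-in for the pre-trained \(W_O\)):

```python
import torch

n_heads, n_tokens, d_head = 32, 8, 128
d_model = n_heads * d_head                          # 4096
out_heads = torch.randn(n_heads, n_tokens, d_head)  # per-head attention output

# [32, 8, 128] -> [8, 32, 128] -> [8, 4096] (concatenate the heads per token)
concat = out_heads.permute(1, 0, 2).reshape(n_tokens, d_model)

# Output projection linearly mixes the insights from all heads
W_O = torch.randn(d_model, d_model)                 # pre-trained in the real model
attn_output = concat @ W_O                          # [8, 4096]
print(attn_output.shape)
```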
5. Residual Connection + LayerNorm (Post-Attention)
After the attention output, we add it back to the original input of the attention layer (residual connection) and then apply LayerNorm to stabilize and normalize the activations.
This is crucial for deep transformer models to:
Prevent vanishing/exploding gradients
Enable smoother training
Allow the model to learn modifications over the input, rather than starting from scratch
Residual Connection

\[ X_{\text{input}} + \text{Output}_{\text{final}} \in \mathbb{R}^{8 \times 4096} \]

Shape:
\(X_{\text{input}} \in \mathbb{R}^{8 \times 4096}\)
\(\text{Output}_{\text{final}} \in \mathbb{R}^{8 \times 4096}\)
So their sum is also in \(\mathbb{R}^{8 \times 4096}\)
Apply LayerNorm
After the residual addition, we apply LayerNorm which normalizes across the feature dimension, not across the batch.
This ensures that for each token, the 4096 features are:
Zero-centered (mean = 0)
Unit variance (std = 1), then rescaled and shifted using learned parameters \(\gamma\) and \(\beta\)
⚠️ Important: LayerNorm parameters are learned, and different for each transformer block.
Final Output of Attention Block:

\[ \text{Output}_{\text{attn\_block}} = \text{LN}_2\!\left(X_{\text{input}} + \text{Output}_{\text{final}}\right) \]

Shape: \(\text{Output}_{\text{attn\_block}} \in \mathbb{R}^{8 \times 4096}\) (batch of 8 tokens, each of hidden size 4096)
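A sketch of the residual connection plus post-attention normalization, using PyTorch's `LayerNorm` module (again, the real model's normalization is RMSNorm, and each block has its own learned parameters):

```python
import torch
import torch.nn as nn

d_model = 4096
X_input = torch.randn(8, d_model)    # input to the attention block
attn_out = torch.randn(8, d_model)   # stand-in for Output_final of the attention block

# Residual connection: add the block's output back onto its input
residual = X_input + attn_out        # [8, 4096]

# Post-attention normalization with this block's own learned gamma/beta
ln2 = nn.LayerNorm(d_model, eps=1e-6)
output_attn_block = ln2(residual)    # [8, 4096]
print(output_attn_block.shape)
```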
6. MLP Feedforward Block with SwiGLU Activation
After the attention block, we pass the \(\text{Output}_{\text{attn\_block}}\) through a feedforward neural network (MLP). This MLP consists of two linear layers with a SwiGLU activation in between:

\[ Z = \text{Output}_{\text{attn\_block}} \cdot W_1 + b_1 \in \mathbb{R}^{8 \times 11008} \]
where
\(W_1 \in \mathbb{R}^{4096 \times 11008}\) (model pre-trained weight matrix for the first linear layer)
\(b_1 \in \mathbb{R}^{11008}\), (model pre-trained bias vector for the first linear layer)
\(Z \in \mathbb{R}^{8 \times 11008}\).
Split \(Z\) into two halves: since SwiGLU needs two inputs \(Z_1, Z_2 \in \mathbb{R}^{8 \times 5504}\), split \(Z\) along the feature dimension:
where
\(Z_1 \in \mathbb{R}^{8 \times 5504}\),
\(Z_2 \in \mathbb{R}^{8 \times 5504}\).
SwiGLU Activation

\[ \text{SwiGLU}(Z) = \text{SiLU}(Z_1) \odot Z_2 \in \mathbb{R}^{8 \times 5504} \]

where:
\(Z = [Z_1, Z_2]\), each half \(\in \mathbb{R}^{8 \times 5504}\)
\(\text{SiLU}(z) = z \cdot \sigma(z)\), where \(\sigma\) is the sigmoid function
\(\odot\) denotes element-wise multiplication
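A sketch of the feed-forward block as described above: one up-projection to 11008, a split in half for SwiGLU, then a down-projection back to 4096. The down-projection \(W_2\) is implied but not spelled out above, so its \(5504 \times 4096\) shape is my assumption for this formulation; note that the actual LLaMA 2 implementation instead uses two separate gate and up projections of width 11008 each (and no bias terms).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 4096, 11008
x = torch.randn(8, d_model)           # stand-in for Output_attn_block

# Random stand-ins for the pre-trained MLP weights
W1 = torch.randn(d_model, d_ff)       # up-projection
b1 = torch.zeros(d_ff)
W2 = torch.randn(d_ff // 2, d_model)  # assumed down-projection: 5504 -> 4096
b2 = torch.zeros(d_model)

Z = x @ W1 + b1                       # [8, 11008]
Z1, Z2 = Z.chunk(2, dim=-1)           # two halves of shape [8, 5504]

# SwiGLU: SiLU(Z1) element-wise multiplied by Z2
hidden = F.silu(Z1) * Z2              # [8, 5504]

out = hidden @ W2 + b2                # back to [8, 4096]
print(out.shape)
```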
🚧 WIP: there is a heck of a lot more to cover 🚧#
[⭕️TODO] Show the cost of tokens within the inference pipeline for single- and multi-user scenarios (memory, FLOPS, latency, etc.)#
[⭕️TODO] Optimization Techniques#
Model Compression#
Quantization
Pruning
Distillation
Model Architecture (e.g. Transformer)#
System Optimization#
Memory Management
Batching
Efficient computation
Scheduling