LLM Inference Optimization#

[Image: random output from a diffusion model]

One of the perks of working at ThoughtWorks is being surrounded by very smart people who constantly keep you on your toes—and never let you stop learning. One such colleague recently triggered some thoughts about inference optimizations, which led me to explore and write a few things down to connect the dots.

BTW, the image above is just a random output from a diffusion model and has absolutely nothing to do with inference optimization. I added it just to reflect the state of my mind—with so many AI concepts floating around—everything starts to make sense… and then in the next moment, it doesn’t. 😉

LLM Inference#

LLM inference is the end-to-end flow of executing a forward pass through a pre-trained large language model (LLM) to generate an output from a given input sequence. This involves processing the input into tokens and applying the model's fixed parameters (learned during training) to transform input tokens into contextually relevant output tokens. Now, this is a way too simplified definition - but let's continue to double-click…

[Image: inference]

The end-to-end pipeline of LLM inference can be conceptualized as Input → Tokenization → Embedding → Prefill → Transformer Decoding / Generation → Output.

[Image: pipeline]
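To make this pipeline concrete, here is a minimal sketch using the Hugging Face transformers API; the checkpoint id, dtype, and max_new_tokens are illustrative placeholders.

# Minimal inference sketch (assumes the transformers library and access to a LLaMA 2 7B checkpoint)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Input → Tokenization
inputs = tokenizer("Why there is always so much to learn?", return_tensors="pt").to(model.device)

# Prefill + Decoding / Generation (the KV cache is used by default)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)

# Output
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))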

It would be too much to get into GPU architecture here - but just for my own curiosity, here is a brief note on where the caching happens on the GPU and where the compute happens.

GDDR (Graphics Double Data Rate) / VRAM is the GPU's global memory, used to store the model weights, activations, and KV caches during LLM inference.

SMs (Streaming Multiprocessors) are the compute units that execute the matrix multiplications, attention, softmax, etc. They contain specialized units for fast matrix multiplies (used heavily in transformer layers) called Tensor Cores.

[Image: gpu]

ref: glossary & basic

Appreciating the Memory and Computation#

Let's appreciate how much computation and memory is required to run inference on a large language model (LLM) like LLaMA 2 7B.

We will also assume that each word in the input sequence is a single token, i.e. the PROMPT Why there is always so much to learn? is tokenized into 8 tokens with IDs: [2993, 1011, 318, 6717, 1121, 632, 311, 13779].

for 🦙 LLaMA 2 7B:#

  • \(N = 8\): Number of input tokens from Why there is always so much to learn?

  • \(d_{model}\): Hidden size = 4096, i.e. the model's embedding dimensionality

    • Dimensionality of the embeddings used throughout the model

    • Every token is represented as a vector of 4096 dimensions inside the model

    • All linear layers (Q/K/V projections, MLPs) operate on this dimensionality

  • \(h\): Number of attention heads = 32

  • \(d_q = d_k = d_v\): Dimensions of each Query, Key and Value vector per head

\[ d_q = d_k = d_v = \frac{d_{\text{model}}}{h} = \frac{4096}{32} = 128 \]

This means:

  • Each of the 32 attention heads operates on 128-dimensional query, key, and value vectors.

  • All heads run in parallel, and their outputs are concatenated to form a vector of size 4096 again:

\[ \text{Concat}(32 \times 128) = 4096 \]

[Image: transformer]

1. Token Embedding Lookup

Each token ID is mapped to a 4096-d vector via the embedding table:

\[ X = E_{token}[t_1, t_2, ..., t_8] \quad \text{where } E_{token} \in \mathbb{R}^{V \times d_{model}} \]

with \(V\) = vocab size (e.g. 32000+) for LLaMA 2 7B.

Shape:

\[ X \in \mathbb{R}^{8 \times 4096} \]
X = token_embedding[input_ids]  # Shape: [8, 4096]

[Image: embed]

2. LayerNorm - Pre-Attention

LayerNorm is typically performed before the attention or MLP blocks in the transformer architecture. Why?

  • It stabilizes training (i.e. ensures consistent scaling) and helps avoid exploding/vanishing gradients.

  • Helps improve convergence i.e. allows better optimization with gradient descent.

LayerNorm 1 i.e. Pre-Attention: \(\text{LN}_1(X) \in \mathbb{R}^{8 \times 4096}\) is given by:

\[ \text{LN}_1(X)_i = \frac{X_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \cdot \gamma + \beta, \quad \text{for } i = 1, \dots, 8 \]

Or more compactly for the full matrix:

\[ \text{LN}_1(X) = \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \]

with:

  • \(X \in \mathbb{R}^{8 \times 4096}\) — the input matrix.

  • \(X_i \in \mathbb{R}^{4096}\) — the i-th row of \(X\), i.e., one token embedding.

  • \(\mu_i = \frac{1}{4096} \sum_{j=1}^{4096} X_{ij}\) — mean of the features for token \(i\).

  • \(\sigma_i^2 = \frac{1}{4096} \sum_{j=1}^{4096} (X_{ij} - \mu_i)^2\) — variance of the features.

  • \(\gamma \in \mathbb{R}^{4096}\) — the model's learned scaling parameter; it learns how much to scale each feature

  • \(\beta \in \mathbb{R}^{4096}\) — the model's learned shifting parameter; it learns how much to shift each feature

  • \(\epsilon\) — a small constant for numerical stability (e.g., \(1e-6\)); it prevents division by zero when the variance is near zero

Shape: the output shape remains the same, i.e. \(X_{norm} \in \mathbb{R}^{8 \times 4096} = [8, 4096]\)
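A minimal PyTorch sketch of this step (shapes follow the example above; the random X and the gamma/beta initializations are placeholders for the block's real activations and learned parameters):

import torch

X = torch.randn(8, 4096)     # embeddings of the 8 prompt tokens (placeholder values)
gamma = torch.ones(4096)     # learned scale (placeholder initialization)
beta = torch.zeros(4096)     # learned shift (placeholder initialization)
eps = 1e-6

mu = X.mean(dim=-1, keepdim=True)                    # per-token mean over the 4096 features
var = X.var(dim=-1, unbiased=False, keepdim=True)    # per-token variance
X_norm = (X - mu) / torch.sqrt(var + eps) * gamma + beta   # [8, 4096]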

3. Linear Projections for Q/K/V vectors

In the attention mechanism, each input token is projected into three different vector spaces:

  • Q = Query i.e. What this token is looking for

  • K = Key i.e. What other tokens contain

  • V = Value i.e. The actual content/information

Together, these projections determine how much attention each token should pay to the others.

for 🦙 LLaMA 2 7B the pre-trained \(Q_{weights}\), \(K_{weights}\), \(V_{weights}\) would be#

\(Q_{weights}\) = \(W_Q \in \mathbb{R}^{4096 \times 4096}\)

  • Learned matrix to project each token’s embedding into a query vector.

  • Tells the model what the token is querying for in the context.

\(K_{weights}\) = \(W_K \in \mathbb{R}^{4096 \times 4096}\)

  • Learned matrix to project the token into a key vector.

  • Encodes what each token offers or matches when compared with queries.

\(V_{weights}\) = \(W_V \in \mathbb{R}^{4096 \times 4096}\)

  • Learned matrix to produce the value vector from each token.

  • Holds the actual content/information to be passed along after attention is calculated.

So… let's compute

\(Q_{vector}\), \(K_{vector}\), \(V_{vector}\)#

\[ Q_{\text{vector}} = X_{\text{norm}} \cdot W_Q \]
\[ K_{\text{vector}} = X_{\text{norm}} \cdot W_K \]
\[ V_{\text{vector}} = X_{\text{norm}} \cdot W_V \]

Where:

  • \(X_{\text{norm}} = \text{LN}_1(X) \in \mathbb{R}^{8 \times 4096}\)

  • \(W_Q, W_K, W_V \in \mathbb{R}^{4096 \times 4096}\) — learned projection weight matrices.

Shape: Output shape \(Q_{\text{vector}}, K_{\text{vector}}, V_{\text{vector}} \in \mathbb{R}^{8 \times 4096} = [8, 4096]\)
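A rough sketch of these projections (random matrices stand in for the pre-trained \(W_Q, W_K, W_V\)):

import torch

X_norm = torch.randn(8, 4096)     # output of LN_1 (placeholder values)
W_Q = torch.randn(4096, 4096)     # placeholder for the pre-trained query projection
W_K = torch.randn(4096, 4096)     # placeholder for the pre-trained key projection
W_V = torch.randn(4096, 4096)     # placeholder for the pre-trained value projection

Q_vector = X_norm @ W_Q   # [8, 4096]
K_vector = X_norm @ W_K   # [8, 4096]
V_vector = X_norm @ W_V   # [8, 4096]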

🦙 LLaMA 2 7B uses 32 heads in each transformer layer. Each of the 32 attention heads operates on 128-dimensional query, key, and value vectors, i.e. the 4096-dimensional representation is divided across the heads:

\[ d_q = d_k = d_v = \frac{d_{\text{model}}}{h} = \frac{4096}{32} = 128 \]

All heads run in parallel, and their outputs are concatenated to form a vector of size 4096 again:

\[ \text{Concat}(32 \times 128) = 4096 \]

Head Splitting (reshape)#

Split the 4096-dimensional vectors into 32 heads of size 128:

\[ Q_{\text{split}} = \text{reshape}(Q_{\text{vector}}, [8, 32, 128]) \quad \in \mathbb{R}^{8 \times 32 \times 128} \]
\[ K_{\text{split}} = \text{reshape}(K_{\text{vector}}, [8, 32, 128]) \quad \in \mathbb{R}^{8 \times 32 \times 128} \]
\[ V_{\text{split}} = \text{reshape}(V_{\text{vector}}, [8, 32, 128]) \quad \in \mathbb{R}^{8 \times 32 \times 128} \]

Transpose for Multi-Head Attention#

Transpose the axes to move heads to the first dimension:

\[ Q_{\text{head}} = \text{transpose}(Q_{\text{split}}, (1, 0, 2)) \quad \in \mathbb{R}^{32 \times 8 \times 128} \]
\[ K_{\text{head}} = \text{transpose}(K_{\text{split}}, (1, 0, 2)) \quad \in \mathbb{R}^{32 \times 8 \times 128} \]
\[ V_{\text{head}} = \text{transpose}(V_{\text{split}}, (1, 0, 2)) \quad \in \mathbb{R}^{32 \times 8 \times 128} \]
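In code, the split and the transpose are just a reshape followed by a permute; a sketch with placeholder tensors:

import torch

Q_vector = torch.randn(8, 4096)   # from the Q/K/V projection step (placeholder values)
K_vector = torch.randn(8, 4096)
V_vector = torch.randn(8, 4096)

# [8, 4096] -> [8, 32, 128] -> [32, 8, 128]
Q_head = Q_vector.reshape(8, 32, 128).permute(1, 0, 2)
K_head = K_vector.reshape(8, 32, 128).permute(1, 0, 2)
V_head = V_vector.reshape(8, 32, 128).permute(1, 0, 2)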

4. Multi-Head Attention

⚫️ 4a. Attention Score Matrix#

The attention score matrix measures how much each token should pay attention to every other token in the sequence.

\[ \text{Scores}_{\text{head}} = \frac{Q_{\text{head}} \cdot K_{\text{head}}^\top}{\sqrt{d_k}} \]

Where:

  • \(Q_{\text{head}}, K_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128}\)

  • For each head: compute dot product between queries and keys for all tokens

  • \(d_k = 128\) is the head dimension — we scale by \(\sqrt{d_k}\) to stabilize gradients

  • For each head and each token:

    • It compares the query vector of that token with the key vectors of all tokens in the sequence

    • Produces a similarity score — high score means: “I’m interested in this token”

i.e. For each head (32 total), compute dot products between all pairs of tokens (8 queries × 8 keys).

\[ \text{Scores}_{\text{head}}[h] = \frac{Q_{\text{head}}[h] \cdot K_{\text{head}}[h]^\top}{\sqrt{128}} \quad \in \mathbb{R}^{8 \times 8} \]

Shape:

\[ \text{Scores}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 8} \]
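A sketch of the score computation (placeholder tensors in place of the real Q/K heads):

import torch

Q_head = torch.randn(32, 8, 128)   # placeholder query heads
K_head = torch.randn(32, 8, 128)   # placeholder key heads
d_k = 128

# batched matmul over the 32 heads: [32, 8, 128] x [32, 128, 8] -> [32, 8, 8]
scores = Q_head @ K_head.transpose(-2, -1) / (d_k ** 0.5)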

⚫️ 4b. Apply Softmax to Attention Score Matrix#

Once we have raw attention scores, we apply softmax across the last dimension. This is required because a. Raw scores can be negative, unbounded, or skewed b. Softmax turns them into probabilities — i.e., values between 0 and 1 that sum to 1 across the sequence

\[ \text{Weights}_{\text{head}} = \text{softmax}(\text{Scores}_{\text{head}}, \text{dim} = -1) \]

This ensures that for each query token, the weights across all keys sum to 1.

⚫️ 4c. Compute Attention Output (Weighted Values)#

This step is the heart of the attention mechanism. It determines what each token should actually “see” or focus on — by combining information from other tokens based on attention scores.

\[ \text{Output}_{\text{head}} = \text{Weights}_{\text{head}} \cdot V_{\text{head}} \]

Shape:

  • \(\text{Weights}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 8}\)

  • \(V_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128}\)

Dot product over the second dimension (sequence length) gives:

\[ \text{Output}_{\text{head}} \in \mathbb{R}^{32 \times 8 \times 128} \]
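Steps 4b and 4c as a short sketch (again with placeholder tensors):

import torch

scores = torch.randn(32, 8, 8)     # raw attention scores from 4a (placeholder values)
V_head = torch.randn(32, 8, 128)   # value heads (placeholder values)

weights = torch.softmax(scores, dim=-1)   # each row of 8 attention weights sums to 1
output_head = weights @ V_head            # [32, 8, 8] x [32, 8, 128] -> [32, 8, 128]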

💡 The attention for a head [h] can be written as (combining 4a, 4b, 4c):

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h \in \mathbb{R}^{8 \times 128} \]

⚫️ 4d. Lastly, Combine Heads (Concatenate)#

Remember that each head captures different features of the input - for LLaMA 2 7B it's like a team of 32 specialists:

  • One looks at grammar

  • One looks at context

  • One tracks word position

  • and so on…

We need to merge the heads back into one tensor per token because a. the final shape must match the model's hidden dimension and b. it enables the output projection layer to linearly mix the insights from all heads

\[ \text{Output}_{\text{concat}} = \text{transpose}(\text{Output}_{\text{head}}, (1, 0, 2)) \rightarrow \mathbb{R}^{8 \times 32 \times 128} \]
\[ \text{Output}_{\text{final}} = \text{reshape}(\text{Output}_{\text{concat}}, [8, 4096]) \]

💡 The full multi-head attention can be written as (combining 4a, 4b, 4c and 4d):

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_{32}) \in \mathbb{R}^{8 \times 4096} \]

where each \(\text{head}_h\), for \(h\) from 1 to 32, is

\[ \text{head}_h = \text{Attention}(Q, K, V)_h = \text{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h \in \mathbb{R}^{8 \times 128} \]
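The concatenation is simply the inverse of the earlier head split; a sketch:

import torch

output_head = torch.randn(32, 8, 128)   # per-head attention outputs (placeholder values)

# [32, 8, 128] -> [8, 32, 128] -> [8, 4096]
output_final = output_head.permute(1, 0, 2).reshape(8, 4096)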

5. Residual Connection + LayerNorm (Post-Attention)

After the attention output, we add it back to the original input of the attention layer (residual connection) and then apply LayerNorm to stabilize and normalize the activations.

This is crucial for deep transformer models to:

  • Prevent vanishing/exploding gradients

  • Enable smoother training

  • Allow the model to learn modifications over the input, rather than starting from scratch

Residual Connection

\[ \text{Residual}_{\text{attn}} = X_{\text{input}} + \text{Attention}(Q, K, V) \]

Shape:

  • \(X_{\text{input}} \in \mathbb{R}^{8 \times 4096}\)

  • \(\text{Output}_{\text{final}} \in \mathbb{R}^{8 \times 4096}\)

  • So their sum is also: \(\mathbb{R}^{8 \times 4096}\)

Apply LayerNorm

After the residual addition, we apply LayerNorm which normalizes across the feature dimension, not across the batch.

\[ \text{Norm}_{\text{attn}} = \text{LayerNorm}(\text{Residual}_{\text{attn}}) \]

This ensures that for each token, the 4096 features are:

  • Zero-centered (mean = 0)

  • Unit variance (std = 1), then rescaled and shifted using learned parameters \(\gamma\) and \(\beta\)

⚠️ Important: LayerNorm parameters are learned, and different for each transformer block.

Final Output of Attention Block:

\[ \text{Output}_{\text{attn\_block}} = \text{LayerNorm}(X_{\text{input}} + \text{Attention}(Q, K, V)) \]

Shape:

\(\text{Output}_{\text{attn\_block}}\) = \(\mathbb{R}^{8 \times 4096}\) (batch of 8 tokens, each of hidden size 4096)
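Putting step 5 together as a sketch (torch.nn.LayerNorm stands in for this block's post-attention norm with its own learned \(\gamma\) and \(\beta\)):

import torch

X_input = torch.randn(8, 4096)        # input to the attention block (placeholder values)
attn_output = torch.randn(8, 4096)    # Output_final from step 4d (placeholder values)

post_attn_norm = torch.nn.LayerNorm(4096, eps=1e-6)   # gamma/beta are learned per block

residual_attn = X_input + attn_output                 # residual connection, [8, 4096]
output_attn_block = post_attn_norm(residual_attn)     # normalized output, [8, 4096]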

6. MLP Feedforward Block with SwiGLU Activation

After the attention block, we pass the \(\text{Output}_{\text{attn\_block}}\) through a feedforward neural network (MLP). This MLP consists of two linear layers with a SwiGLU activation in between.

\[ Z = \text{Output}_{\text{attn\_block}} \cdot W_1 + b_1 \in \mathbb{R}^{8 \times 11008} \]

where

  • \(W_1 \in \mathbb{R}^{4096 \times 11008}\) (model pre-trained weight matrix for the first linear layer)

  • \(b_1 \in \mathbb{R}^{11008}\), (model pre-trained bias vector for the first linear layer)

  • \(Z \in \mathbb{R}^{8 \times 11008}\).

Split \(Z\) into two halves: since SwiGLU needs two vectors \(Z_1, Z_2 \in \mathbb{R}^{5504}\) (per token), split \(Z\) along the feature dimension:

\[ Z = [Z_1, Z_2] \]

where

  • \(Z_1 \in \mathbb{R}^{8 \times 5504}\),

  • \(Z_2 \in \mathbb{R}^{8 \times 5504}\).

SwiGLU Activation

With \(Z = [Z_1, Z_2]\), where each half \(\in \mathbb{R}^{8 \times 5504}\):

\[ H = \text{SwiGLU}(Z) = \text{SiLU}(Z_1) \odot Z_2 \in \mathbb{R}^{8 \times 5504} \]
  • \(\text{SiLU}(z) = z \cdot \sigma(z)\) where \(\sigma\) is sigmoid,

  • \(\odot\) denotes element-wise multiplication.

Project back to hidden size#

\[ \text{mlp\_output} = H \cdot W_2 + b_2 \in \mathbb{R}^{8 \times 4096} \]

where

  • \(W_2 \in \mathbb{R}^{5504 \times 4096}\) (model pre-trained weight matrix for the second linear layer)

  • \(b_2 \in \mathbb{R}^{4096}\) (model pre-trained bias vector for the second linear layer)

  • Output shape: \(\mathbb{R}^{8 \times 4096}\).
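The whole MLP block as described above, sketched with placeholder weights and biases (real checkpoints ship these as pre-trained tensors; the split-in-half formulation follows this walkthrough):

import torch
import torch.nn.functional as F

x = torch.randn(8, 4096)        # Output_attn_block (placeholder values)
W1 = torch.randn(4096, 11008)   # placeholder first linear layer
b1 = torch.randn(11008)
W2 = torch.randn(5504, 4096)    # placeholder second linear layer
b2 = torch.randn(4096)

Z = x @ W1 + b1                 # [8, 11008]
Z1, Z2 = Z.chunk(2, dim=-1)     # split into two halves, each [8, 5504]
H = F.silu(Z1) * Z2             # SwiGLU: SiLU(Z1) ⊙ Z2 -> [8, 5504]
mlp_output = H @ W2 + b2        # project back to the hidden size -> [8, 4096]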

7. Residual (Post-MLP) - Transformer Block Output

\[ X_{\text{final}} = \text{Residual}_{\text{attn}} + \text{mlp\_output} \in \mathbb{R}^{8 \times 4096} \]

This is the final output of the transformer block. \(X_{\text{final}}\) contains a transformed version of the input \(X\) across 4096 dimensions. For LLaMA 2 7B, this is repeated 32 times (i.e. 32 transformer blocks) before predicting the next token.

# Token Embedding
X0 = token_embedding[input_ids]   # [8, 4096]

# 32 Transformer Blocks
X1 = transformer_block_1(X0)      # [8, 4096]
X2 = transformer_block_2(X1)      # [8, 4096]
...
X32 = transformer_block_32(X31)   # Final hidden state, [8, 4096]

8. Final Token Prediction

  • There are 8 tokens in the input sequence (sequence length = 8),

  • Each token is represented by a 4096-dimensional vector (embedding size = 4096),

  • Residual_attn is the output of the self-attention mechanism plus its residual connection,

  • mlp_output is the output of the feed-forward network (MLP) in the Transformer block.

The model takes X_final (the output of the 32nd Transformer block) and passes it through a final linear layer (also called the output head) that maps each 4096-dimensional vector to a vector of vocabulary size \(\text{vocab\_size}\):

\[ \text{logits} = X_{\text{final}} \cdot E_{\text{token}}^\top \in \mathbb{R}^{8 \times \text{vocab\_size}} \]
  • \(E_{\text{token}}^\top \in \mathbb{R}^{4096 \times \text{vocab\_size}}\) is the transposed token embedding matrix (shared or separate).

  • So for each of the 8 tokens, you get a probability distribution over the entire vocabulary.

Focus on the Last Token: Since we are predicting the next token, the model usually focuses only on the last position in the sequence (i.e., row 8 of the matrix if indexing from 1).

\[ \text{logits}^{(8)} = X_{\text{final}}^{(8)} \cdot E_{\text{token}}^\top \in \mathbb{R}^{\text{vocab\_size}} \]

The logits vector is passed through a softmax to convert it into a probability distribution over the vocabulary:

\[ P(\text{token}_{9} = i) = \frac{e^{\text{logits}^{(8)}_i}}{\sum_{j=1}^{V} e^{\text{logits}^{(8)}_j}} \]

The index \(i\) with the highest probability corresponds to the predicted next token.
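A sketch of this last step (E_token and vocab_size = 32000 are placeholders for the model's real embedding / output-head matrix):

import torch

X_final = torch.randn(8, 4096)        # output of the 32nd transformer block (placeholder values)
E_token = torch.randn(32000, 4096)    # placeholder embedding / LM-head matrix [vocab_size, d_model]

logits = X_final @ E_token.T                # [8, 32000]
probs = torch.softmax(logits[-1], dim=-1)   # distribution over the vocab for the last position
next_token_id = torch.argmax(probs).item()  # greedy pick of the predicted next token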

🚧 WIP: there is a heck of a lot more to cover 🚧#

[⭕️TODO] Show the cost of tokens within the inference pipeline for single and multi-user scenarios (memory, FLOPS, latency, etc.)#

[⭕️TODO] Optimization Techniques#

Model Compression#

  • Quantization

  • Pruning

  • Distillation

Model Architecture (e.g. Transformer)#

System Optimization#

  • Memory Management

  • Batching

  • Efficient computation

  • Scheduling