Archive notes for computing R(m) using adjacent pairs
I started with this approach but later realized that Qwen3 uses the rotate_half pattern.
Computing R(m)
head dim \(d = 8\) (even)
number of RoPE pairs \(= d/2 = 4\)
Assume 6 input tokens ⇒ positions \(m \in \{0,1,2,3,4,5\}\)
base \(\theta = 10000\) (the usual default for theta)
Any head vector (query or key) for one token would look like this (for \(d = 8\)):
\[ q = [q_0, q_1, q_2, q_3, q_4, q_5, q_6, q_7] \]
⚫️ a) Pair creation: RoPE pairs adjacent coordinates (same for Q and K):
⚠️ Note: the Qwen3 HF code pairs (0, D/2), (1, D/2+1), (2, D/2+2), … i.e. it uses the classic rotate_half pattern. To keep things simple, I use adjacent pairs here.
pair 0: \((q_0, q_1)\)
pair 1: \((q_2, q_3)\)
pair 2: \((q_4, q_5)\)
pair 3: \((q_6, q_7)\)
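To make the difference between the two pairing schemes concrete, here is a small sketch that lists the index pairs for both layouts. The helper names `adjacent_pairs` and `rotate_half_pairs` are made up for this illustration and do not come from any library.

```python
def adjacent_pairs(d):
    """Index pairs (0,1), (2,3), ... as used in these notes."""
    return [(2 * i, 2 * i + 1) for i in range(d // 2)]


def rotate_half_pairs(d):
    """Index pairs (0, d/2), (1, d/2+1), ... as in the rotate_half layout."""
    return [(i, i + d // 2) for i in range(d // 2)]


print("adjacent   :", adjacent_pairs(8))
print("rotate_half:", rotate_half_pairs(8))
```

Both layouts rotate the same number of 2D pairs; they only differ in which coordinates are grouped together.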
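⚫️ b) Compute \(\omega\) (omega) for each pair i, i.e. a per-pair frequency (a constant for pair i, the same for all tokens)
With
\[ \omega_i = \theta^{-2i/d}, \quad i = 0, 1, \dots, d/2 - 1 \]
and \(\theta = 10000\), \(d = 8\), this gives \(\omega_i = 10000^{-i/4} = 10^{-i}\).
So for each pair i = (0, 1, 2, 3):
\(\omega_0 = 1\)
\(\omega_1 = 0.1\)
\(\omega_2 = 0.01\)
\(\omega_3 = 0.001\)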
import torch
theta = 10000.0
d = 8
i = torch.arange(0, d//2, dtype=torch.float32) # [0,1,2,3]
omega = theta ** (-2*i/d)
print("omega:", omega) # [1.0, 0.1, 0.01, 0.001]
omega: tensor([1.00, 0.10, 0.01, 0.00])
⚫️ c) Calculate the rotation angle \(\phi_{m,i}\) (phi) per input token (denoted by position m) and per pair i
\[ \phi_{m,i} = m \cdot \omega_i \]
So with 6 input tokens ⇒ positions \(m \in \{0,1,2,3,4,5\}\), we get 6 sets of 4 angles.
import torch
omega = torch.tensor([1.0, 0.1, 0.01, 0.001])
for m in range(6):
    phi = m * omega
    print(f"m={m} --> ϕm={[float(f'{x:.2f}') for x in phi.tolist()]}")
m=0 --> ϕm=[0.0, 0.0, 0.0, 0.0]
m=1 --> ϕm=[1.0, 0.1, 0.01, 0.0]
m=2 --> ϕm=[2.0, 0.2, 0.02, 0.0]
m=3 --> ϕm=[3.0, 0.3, 0.03, 0.0]
m=4 --> ϕm=[4.0, 0.4, 0.04, 0.0]
m=5 --> ϕm=[5.0, 0.5, 0.05, 0.01]
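⚫️ d) Build a 2x2 rotation block for one pair in \(q^{(m)}\) and angle \(\phi_{m,i}\)
One token at position \(m\) has a query vector:
\[ q^{(m)} = [q^{(m)}_0, q^{(m)}_1, \dots, q^{(m)}_7] \]
RoPE produces the rotated query:
\[ \tilde{q}^{(m)} = [\tilde{q}^{(m)}_0, \tilde{q}^{(m)}_1, \dots, \tilde{q}^{(m)}_7] \]
where \(d = 8\), i.e. 4 pairs: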
pair 0: \((q^{(m)}_0, q^{(m)}_1)\) with angle \(\phi_{m,0}\)
pair 1: \((q^{(m)}_2, q^{(m)}_3)\) with angle \(\phi_{m,1}\)
pair 2: \((q^{(m)}_4, q^{(m)}_5)\) with angle \(\phi_{m,2}\)
pair 3: \((q^{(m)}_6, q^{(m)}_7)\) with angle \(\phi_{m,3}\)
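Pair 0 (indices 0,1)
\[ \begin{bmatrix} \tilde{q}^{(m)}_0 \\ \tilde{q}^{(m)}_1 \end{bmatrix} = \begin{bmatrix} \cos\phi_{m,0} & -\sin\phi_{m,0} \\ \sin\phi_{m,0} & \cos\phi_{m,0} \end{bmatrix} \begin{bmatrix} q^{(m)}_0 \\ q^{(m)}_1 \end{bmatrix} \]
Pair 1 (indices 2,3)
\[ \begin{bmatrix} \tilde{q}^{(m)}_2 \\ \tilde{q}^{(m)}_3 \end{bmatrix} = \begin{bmatrix} \cos\phi_{m,1} & -\sin\phi_{m,1} \\ \sin\phi_{m,1} & \cos\phi_{m,1} \end{bmatrix} \begin{bmatrix} q^{(m)}_2 \\ q^{(m)}_3 \end{bmatrix} \]
Pair 2 (indices 4,5)
\[ \begin{bmatrix} \tilde{q}^{(m)}_4 \\ \tilde{q}^{(m)}_5 \end{bmatrix} = \begin{bmatrix} \cos\phi_{m,2} & -\sin\phi_{m,2} \\ \sin\phi_{m,2} & \cos\phi_{m,2} \end{bmatrix} \begin{bmatrix} q^{(m)}_4 \\ q^{(m)}_5 \end{bmatrix} \]
Pair 3 (indices 6,7)
\[ \begin{bmatrix} \tilde{q}^{(m)}_6 \\ \tilde{q}^{(m)}_7 \end{bmatrix} = \begin{bmatrix} \cos\phi_{m,3} & -\sin\phi_{m,3} \\ \sin\phi_{m,3} & \cos\phi_{m,3} \end{bmatrix} \begin{bmatrix} q^{(m)}_6 \\ q^{(m)}_7 \end{bmatrix} \]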
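Derivative w.r.t. the query components (pairwise Jacobian)
For pair 0, the Jacobian of \([\tilde{q}_0,\tilde{q}_1]\) w.r.t. \([q_0,q_1]\) is:
\[ \frac{\partial(\tilde{q}_0, \tilde{q}_1)}{\partial(q_0, q_1)} = \begin{bmatrix} \cos\phi_{m,0} & -\sin\phi_{m,0} \\ \sin\phi_{m,0} & \cos\phi_{m,0} \end{bmatrix} \]
The same form holds for pairs (2,3), (4,5), (6,7) using their own angles.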
Derivative w.r.t the angle \(\phi_{m,0}\) (example for pair 0)#
E.g. rotate \(q\) for a single token at position m=5:
import torch
def rope_omega(d=8, theta=10000.0):
    i = torch.arange(0, d//2, dtype=torch.float32) # [0,1,2,3]
    return theta ** (-2*i/d)

def rope_rotate_q(q, m, theta=10000.0):
    # q: shape (8,)
    d = q.shape[-1]
    omega = rope_omega(d, theta).to(q.device)
    phi = m * omega
    cos, sin = torch.cos(phi), torch.sin(phi)
    q_tilde = q.clone()
    # pair 0 (0,1)
    q_tilde[0] = q[0]*cos[0] - q[1]*sin[0]
    q_tilde[1] = q[0]*sin[0] + q[1]*cos[0]
    # pair 1 (2,3)
    q_tilde[2] = q[2]*cos[1] - q[3]*sin[1]
    q_tilde[3] = q[2]*sin[1] + q[3]*cos[1]
    # pair 2 (4,5)
    q_tilde[4] = q[4]*cos[2] - q[5]*sin[2]
    q_tilde[5] = q[4]*sin[2] + q[5]*cos[2]
    # pair 3 (6,7)
    q_tilde[6] = q[6]*cos[3] - q[7]*sin[3]
    q_tilde[7] = q[6]*sin[3] + q[7]*cos[3]
    return q_tilde, phi

q = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
q_tilde, phi = rope_rotate_q(q, m=5)
print("phi:", phi)
print("q:", q)
print("q_tilde:", q_tilde)
phi: tensor([5.00, 0.50, 0.05, 0.01])
q: tensor([1., 2., 3., 4., 5., 6., 7., 8.])
q_tilde: tensor([ 2.20, -0.39, 0.72, 4.95, 4.69, 6.24, 6.96, 8.03])
⚫️ e) Lastly, we formulate the full \(R(m)\) matrix for \(d=8\), i.e. 8x8
Define \(R(m)\in\mathbb{R}^{8\times 8}\) such that:
\[ \tilde{q}^{(m)} = R(m)\, q^{(m)} \]
\(R(m)\) is block-diagonal with four 2×2 blocks:
rows/cols 0..1 use \(\phi_{m,0}\)
rows/cols 2..3 use \(\phi_{m,1}\)
rows/cols 4..5 use \(\phi_{m,2}\)
rows/cols 6..7 use \(\phi_{m,3}\)
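So explicitly, writing \(c_i = \cos\phi_{m,i}\) and \(s_i = \sin\phi_{m,i}\), \(R(m)\) is:
\[ R(m) = \begin{bmatrix}
c_0 & -s_0 & 0 & 0 & 0 & 0 & 0 & 0 \\
s_0 & c_0 & 0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & c_1 & -s_1 & 0 & 0 & 0 & 0 \\
0 & 0 & s_1 & c_1 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & c_2 & -s_2 & 0 & 0 \\
0 & 0 & 0 & 0 & s_2 & c_2 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & c_3 & -s_3 \\
0 & 0 & 0 & 0 & 0 & 0 & s_3 & c_3
\end{bmatrix} \]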
Derivative w.r.t. q (full)
Since \(\tilde{q}^{(m)} = R(m) \cdot q^{(m)}\) is linear in \(q^{(m)}\):
\[ \frac{\partial \tilde{q}^{(m)}}{\partial q^{(m)}} = R(m) \]
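Because the rotation is linear in \(q\), its Jacobian with respect to \(q\) should be \(R(m)\) itself. The sketch below checks this numerically with `torch.autograd.functional.jacobian`; the helper `rotate_adjacent` is written just for this check and assumes the adjacent-pair layout and the \(d=8\), \(\theta=10000\), \(m=5\) setup used in these notes.

```python
import torch


def rotate_adjacent(q, phi):
    """Rotate adjacent pairs (q[2i], q[2i+1]) by the angle phi[i]."""
    cos, sin = torch.cos(phi), torch.sin(phi)
    even, odd = q[0::2], q[1::2]
    out_even = even * cos - odd * sin
    out_odd = even * sin + odd * cos
    # Interleave back to [even_0, odd_0, even_1, odd_1, ...]
    return torch.stack([out_even, out_odd], dim=-1).reshape(-1)


d, theta, m = 8, 10000.0, 5
omega = theta ** (-2 * torch.arange(d // 2, dtype=torch.float32) / d)
phi = m * omega
q = torch.arange(1.0, d + 1)  # [1., 2., ..., 8.]

# Jacobian of the rotated vector w.r.t. q, shape (8, 8)
J = torch.autograd.functional.jacobian(lambda v: rotate_adjacent(v, phi), q)

# Expected block-diagonal R(m), filled in pair by pair
R = torch.zeros(d, d)
for i in range(d // 2):
    c, s = torch.cos(phi[i]), torch.sin(phi[i])
    R[2 * i, 2 * i], R[2 * i, 2 * i + 1] = c, -s
    R[2 * i + 1, 2 * i], R[2 * i + 1, 2 * i + 1] = s, c

print("Jacobian equals R(m):", torch.allclose(J, R, atol=1e-6))
```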
⚠️ For a given token at m=5, check numerically that the matrix form (geometric view) \(R(m) \cdot q^{(m)}\) equals the pairwise form (rotating each pair manually) \(\tilde{q}^{(m)}\). Nothing fancy; it's just to see that the matrix form is a compact way of writing the individual pairwise rotations.
import torch
def build_Rm(m, d=8, theta=10000.0):
    # Reuses rope_omega() defined in the earlier cell
    omega = rope_omega(d, theta)
    phi = m * omega
    cos, sin = torch.cos(phi), torch.sin(phi)
    Rm = torch.zeros(d, d)
    # pair 0
    Rm[0,0], Rm[0,1] = cos[0], -sin[0]
    Rm[1,0], Rm[1,1] = sin[0], cos[0]
    # pair 1
    Rm[2,2], Rm[2,3] = cos[1], -sin[1]
    Rm[3,2], Rm[3,3] = sin[1], cos[1]
    # pair 2
    Rm[4,4], Rm[4,5] = cos[2], -sin[2]
    Rm[5,4], Rm[5,5] = sin[2], cos[2]
    # pair 3
    Rm[6,6], Rm[6,7] = cos[3], -sin[3]
    Rm[7,6], Rm[7,7] = sin[3], cos[3]
    return Rm, phi

# For a single token at position m=5
q = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
Rm, phi = build_Rm(m=5)
q_tilde_mat = Rm @ q
q_tilde_pair, _ = rope_rotate_q(q, m=5)
print("phi:", phi)
print("q:", q)
print("R(m):", Rm)
print("q_tilde (matrix form):", q_tilde_mat)
print("q_tilde (pairwise form):", q_tilde_pair)
print("Difference:", (q_tilde_mat - q_tilde_pair).abs().max().item())
phi: tensor([5.00, 0.50, 0.05, 0.01])
q: tensor([1., 2., 3., 4., 5., 6., 7., 8.])
R(m): tensor([[ 0.28, 0.96, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[-0.96, 0.28, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[ 0.00, 0.00, 0.88, -0.48, 0.00, 0.00, 0.00, 0.00],
[ 0.00, 0.00, 0.48, 0.88, 0.00, 0.00, 0.00, 0.00],
[ 0.00, 0.00, 0.00, 0.00, 1.00, -0.05, 0.00, 0.00],
[ 0.00, 0.00, 0.00, 0.00, 0.05, 1.00, 0.00, 0.00],
[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, -0.00],
[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00]])
q_tilde (matrix form): tensor([ 2.20, -0.39, 0.72, 4.95, 4.69, 6.24, 6.96, 8.03])
q_tilde (pairwise form): tensor([ 2.20, -0.39, 0.72, 4.95, 4.69, 6.24, 6.96, 8.03])
Difference: 0.0
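The hand-written `build_Rm` above is hard-coded for four pairs. As a possible generalization, the sketch below builds the same block-diagonal matrix for any even `d` with `torch.block_diag`, and also checks that the result is orthogonal (each block is a plain 2D rotation, so \(R(m)^\top R(m) = I\)). The name `build_Rm_general` is made up for this illustration.

```python
import torch


def build_Rm_general(m, d, theta=10000.0):
    """Block-diagonal R(m) for any even d, adjacent-pair layout."""
    assert d % 2 == 0, "d must be even"
    i = torch.arange(d // 2, dtype=torch.float32)
    phi = m * theta ** (-2 * i / d)
    cos, sin = torch.cos(phi), torch.sin(phi)
    # Stack of d/2 rotation blocks [[cos, -sin], [sin, cos]], shape (d/2, 2, 2)
    blocks = torch.stack(
        [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)],
        dim=-2,
    )
    # Unpacking iterates over the first dimension, giving one 2x2 block per pair
    return torch.block_diag(*blocks)


R = build_Rm_general(m=5, d=8)
print("shape:", R.shape)
print("orthogonal:", torch.allclose(R.T @ R, torch.eye(8), atol=1e-6))
```

Orthogonality means the rotation changes the direction of each pair but not the norm of the vector.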
============== RoPE functions (adjacent-pair variant, not Qwen3's rotate_half layout) ==============
compute_rope_params_adjacent(...)
Precomputes cos and sin tables for all positions up to context_length.
For each pair \(i \in [0, \dots, d/2-1]\), compute the frequency \(\omega_i\).
For each position \(m \in [0, \dots, \text{context\_length}-1]\), compute the angle \(\phi_{m,i} = m \cdot \omega_i\).
Store \(\cos(\phi_{m,i})\) and \(\sin(\phi_{m,i})\).
import torch
def compute_rope_params_adjacent(
    head_dim,
    theta_base = 10_000,
    context_length = 4096,
    dtype=torch.float32,
):
    assert head_dim % 2 == 0, "head_dim must be even"
    # Pair index i = 0..(head_dim//2 - 1)
    i = torch.arange(0, head_dim // 2, dtype=dtype)
    # Frequencies omega[i] = theta_base^(-2i/head_dim)
    omega = theta_base ** (-2.0 * i / head_dim) # (head_dim//2,)
    # Positions m = 0..context_length-1
    positions = torch.arange(context_length, dtype=dtype) # (context_length,)
    # Angles phi[m, i] = m * omega[i]
    angles = positions[:, None] * omega[None, :] # (context_length, head_dim//2)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin, angles

cos, sin, angles = compute_rope_params_adjacent(
    head_dim=8,
    theta_base=10_000.0,
    context_length=6,
)
print("cos:", cos)
print("sin:", sin)
print("angles:", angles)
cos: tensor([[ 1.0000, 1.0000, 1.0000, 1.0000],
[ 0.5403, 0.9950, 0.9999, 1.0000],
[-0.4161, 0.9801, 0.9998, 1.0000],
[-0.9900, 0.9553, 0.9996, 1.0000],
[-0.6536, 0.9211, 0.9992, 1.0000],
[ 0.2837, 0.8776, 0.9988, 1.0000]])
sin: tensor([[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.8415, 0.0998, 0.0100, 0.0010],
[ 0.9093, 0.1987, 0.0200, 0.0020],
[ 0.1411, 0.2955, 0.0300, 0.0030],
[-0.7568, 0.3894, 0.0400, 0.0040],
[-0.9589, 0.4794, 0.0500, 0.0050]])
angles: tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
[2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
[3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
[4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
[5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03]])
apply_rope_adjacent(x, cos, sin, offset)
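Applies RoPE to a Q/K tensor of shape (batch, num_heads, num_tokens, head_dim), i.e. (B, H, T, D). offset is the position of the first token in the chunk, so token t is rotated with the angles of position offset + t.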
import torch
def apply_rope_adjacent(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, offset: int = 0):
    BATCH, NUM_HEADS, NUM_TOKENS, D = x.shape
    assert D % 2 == 0, "head_dim must be even"
    assert cos.shape[1] == D // 2 and sin.shape[1] == D // 2, "cos/sin must be (context_length, D//2)"
    # A too-short table is sliced to fewer rows; a single row would silently broadcast to every token
    assert offset + NUM_TOKENS <= cos.shape[0], "cos/sin must cover positions up to offset + NUM_TOKENS - 1"
    # Select positions [offset .. offset+NUM_TOKENS-1] for this chunk
    cos_s = cos[offset:offset + NUM_TOKENS, :].to(device=x.device, dtype=x.dtype) # (NUM_TOKENS, D//2)
    sin_s = sin[offset:offset + NUM_TOKENS, :].to(device=x.device, dtype=x.dtype) # (NUM_TOKENS, D//2)
    # Broadcast to (1, 1, NUM_TOKENS, D//2, 1)
    cos_s = cos_s.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    sin_s = sin_s.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    # Adjacent pairs: (BATCH, NUM_HEADS, NUM_TOKENS, D//2, 2)
    x_pair = x.reshape(BATCH, NUM_HEADS, NUM_TOKENS, D // 2, 2)
    x_even = x_pair[..., 0:1]
    x_odd = x_pair[..., 1:2]
    # Rotate each pair with its position-dependent angle
    y_even = x_even * cos_s - x_odd * sin_s
    y_odd = x_even * sin_s + x_odd * cos_s
    # Re-pack pairs back to (BATCH, NUM_HEADS, NUM_TOKENS, D)
    y = torch.cat([y_even, y_odd], dim=-1).reshape(BATCH, NUM_HEADS, NUM_TOKENS, D)
    return y

if __name__ == "__main__":
    torch.manual_seed(0)
    BATCH, NUM_HEADS, NUM_TOKENS, D = 2, 3, 6, 8
    OFFSET = 5  # the chunk covers positions 5..10, so token 0 sits at position 5
    x = torch.randn(BATCH, NUM_HEADS, NUM_TOKENS, D)
    cos, sin, angles = compute_rope_params_adjacent(
        head_dim=D,
        theta_base=10_000.0,
        context_length=OFFSET + NUM_TOKENS,  # tables must cover every position in the chunk
    )
    y = apply_rope_adjacent(x, cos, sin, offset=OFFSET)
    print("Original Tensor (x) shape:", x.shape)
    print("Rotated Tensor (y) shape:", y.shape)
    print("Tensor [0, 0, 0] before RoPE:", x[0, 0, 0])
    print("Tensor [0, 0, 0] after RoPE:", y[0, 0, 0])
Original Tensor (x) shape: torch.Size([2, 3, 6, 8])
Rotated Tensor (y) shape: torch.Size([2, 3, 6, 8])
Tensor [0, 0, 0] before RoPE: tensor([-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152])
Tensor [0, 0, 0] after RoPE: tensor([-1.4244, 0.7527, -0.0119, -0.5009, 0.8131, 0.7336, -0.3054, -2.1168])
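Since the note at the top says Qwen3 uses the rotate_half layout rather than adjacent pairs, it may help to see how the two relate. The sketch below is my own minimal version of the commonly used rotate_half formulation (cos/sin duplicated across both halves), not a copy of the Qwen3 HF code, and the function names `apply_rope_half` and `apply_rope_adjacent_2d` are hypothetical. It suggests that the two layouts give the same result up to a fixed permutation of the head dimension.

```python
import torch


def rotate_half(x):
    """Map [x1, x2] -> [-x2, x1], where x1 / x2 are the first / second half of the last dim."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rope_half(x, cos, sin):
    """x: (T, D); cos/sin: (T, D//2). Pairs are (i, i + D/2)."""
    cos_full = torch.cat([cos, cos], dim=-1)  # (T, D)
    sin_full = torch.cat([sin, sin], dim=-1)  # (T, D)
    return x * cos_full + rotate_half(x) * sin_full


def apply_rope_adjacent_2d(x, cos, sin):
    """x: (T, D); cos/sin: (T, D//2). Pairs are (2i, 2i+1)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    y_even = even * cos - odd * sin
    y_odd = even * sin + odd * cos
    return torch.stack([y_even, y_odd], dim=-1).flatten(-2)


T, D, theta = 6, 8, 10000.0
omega = theta ** (-2 * torch.arange(D // 2, dtype=torch.float32) / D)
angles = torch.arange(T, dtype=torch.float32)[:, None] * omega[None, :]
cos, sin = torch.cos(angles), torch.sin(angles)

torch.manual_seed(0)
x = torch.randn(T, D)

# Permutation from the adjacent layout to the half layout: [0, 2, 4, 6, 1, 3, 5, 7]
perm = torch.cat([torch.arange(0, D, 2), torch.arange(1, D, 2)])

y_adj = apply_rope_adjacent_2d(x, cos, sin)
y_half = apply_rope_half(x[:, perm], cos, sin)
print("same up to permutation:", torch.allclose(y_adj[:, perm], y_half, atol=1e-6))
```

In other words, the choice of pairing does not change the math of the rotation, only which coordinates share an angle. Weights trained with one layout are not interchangeable with the other unless the same permutation is applied to Q and K.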