The full implementation of what we'll cover today is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py
1. Understanding Rotary Position Embeddings (RoPE)
Before we delve into Su-scaled RoPE, let's first understand the basics of Rotary Position Embeddings.
RoPE injects positional information into the model's token representations without adding extra tokens or learned parameters. The key idea is to rotate each token's query and key vectors, inside every attention head, by angles proportional to the token's position in the sequence. As a result, the attention score between two tokens depends on their relative distance.
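- Frequency Calculation: RoPE treats the per-head embedding of size dim as dim / 2 pairs of dimensions, indexed by d = 0, 2, 4, ..., dim - 2. For each pair it calculates an inverse frequency from a base theta (rope_theta in the model config):
  inv_freq[d/2] = 1 / (theta ** (d / dim))
  Early pairs get frequencies close to 1 and rotate quickly. Later pairs get much smaller frequencies and rotate slowly.
- Position-Frequency Interaction: Each inverse frequency is multiplied by each token position (an outer product of positions and frequencies). This gives every position a unique set of rotation angles:
  freqs[pos, d/2] = pos * inv_freq[d/2]
- Rotation Application: Each pair of dimensions (x, y) of a token's query and key vectors is treated as a point in a 2D plane and rotated by its angle. For a token at position pos:
  x_rotated = x * cos(pos * inv_freq) - y * sin(pos * inv_freq)
  y_rotated = y * cos(pos * inv_freq) + x * sin(pos * inv_freq)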
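Here is a minimal pure-Python sketch (no MLX) of plain RoPE applied to a single vector, to make the three steps above concrete. Some details are assumptions for illustration:

- It pairs adjacent dimensions (2i, 2i + 1) for readability.
- It uses theta = 10000 as the base.
- The helper names rope_inv_freqs, rope_rotate and dot are hypothetical.

It checks the two properties that motivate RoPE. First, rotation leaves a vector's length unchanged. Second, the dot product of a rotated query and key depends only on the distance between their positions.

```python
import math


def rope_inv_freqs(dim, theta=10000.0):
    # One inverse frequency per dimension pair: d = 0, 2, ..., dim - 2
    return [1.0 / (theta ** (d / dim)) for d in range(0, dim, 2)]


def rope_rotate(vec, pos, theta=10000.0):
    # Rotate each adjacent pair (vec[2i], vec[2i + 1]) by the angle pos * inv_freq[i]
    out = []
    for i, f in enumerate(rope_inv_freqs(len(vec), theta)):
        x, y = vec[2 * i], vec[2 * i + 1]
        angle = pos * f
        out.append(x * math.cos(angle) - y * math.sin(angle))
        out.append(y * math.cos(angle) + x * math.sin(angle))
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5, -0.7, 0.1, 1.1, -0.4]
k = [0.9, 0.2, -0.6, 1.3, 0.4, -0.8, 0.05, 0.7]

# Rotation preserves the vector's length
print(math.isclose(dot(q, q), dot(rope_rotate(q, 7), rope_rotate(q, 7))))  # True

# The query-key score depends only on the offset between positions (3 in both cases)
s1 = dot(rope_rotate(q, 10), rope_rotate(k, 7))
s2 = dot(rope_rotate(q, 103), rope_rotate(k, 100))
print(math.isclose(s1, s2, abs_tol=1e-9))  # True
```

The relative-position property holds because, within each pair, rotating q by angle a and k by angle b changes their dot product only through a - b.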
Now that we understand RoPE, let's explore how Su-scaled RoPE builds upon and enhances this concept.
2. Understanding Su-RoPE
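Su-RoPE extends RoPE by rescaling each rotary frequency with a per-dimension factor. This allows the model to better generalize to sequences longer than those seen during training. The inverse frequency becomes:
inv_freq[d/2] = 1 / (su_factor[d/2] * theta ** (d / dim))
Here su_factor holds one value per dimension pair (dim / 2 values in total). A factor greater than 1 lowers that pair's frequency, so the pair rotates more slowly and completes fewer turns over a long sequence.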
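The frequency rescaling above can be sketched in plain Python. The factor values are made up for illustration; real values come from the rope_scaling entry of the model config. The helper su_inv_freqs is hypothetical. With every factor set to 1.0 it reduces to plain RoPE. Each factor stretches the wavelength (2 * pi / inv_freq) of its dimension pair by exactly that factor.

```python
import math


def su_inv_freqs(dim, su_factor, theta=10000.0):
    # inv_freq[i] = 1 / (su_factor[i] * theta ** (2i / dim)); one factor per dimension pair
    if len(su_factor) != dim // 2:
        raise ValueError("su_factor needs dim // 2 entries")
    return [1.0 / (s * theta ** (d / dim)) for s, d in zip(su_factor, range(0, dim, 2))]


dim = 8
base = su_inv_freqs(dim, [1.0] * (dim // 2))  # all factors 1.0: plain RoPE
factors = [1.0, 1.5, 3.0, 8.0]  # hypothetical factors, for illustration only
scaled = su_inv_freqs(dim, factors)

for b, s in zip(base, scaled):
    # Wavelength (in positions) of each dimension pair, before and after rescaling
    print(f"{2 * math.pi / b:10.2f} -> {2 * math.pi / s:10.2f}")
```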
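- Short and Long Factors: Two sets of scaling factors are used: short_factor for shorter sequences and long_factor for longer ones.
- Adaptive Scaling: The choice between them depends on the sequence length. If it exceeds original_max_position_embeddings (the model's original, pre-extension context length), long_factor is used; otherwise short_factor is used.
- Scaling Factor: An additional factor, sqrt(1 + ln(max_position_embeddings / original_max_position_embeddings) / ln(original_max_position_embeddings)), multiplies the cosines and sines to compensate for the extended maximum position embeddings.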
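Both rules can be checked with a few lines of arithmetic. This sketch makes the following assumptions:

- Sequence length is the largest position id plus one.
- Example values are a 4096-token original window extended to 131072 (128K) positions.
- The helper names select_su_factor and attention_scaling are hypothetical.

With these values, ln(32) / ln(4096) = 5 ln 2 / (12 ln 2) = 5/12. The scaling factor is therefore sqrt(17/12), roughly 1.19.

```python
import math


def select_su_factor(position_ids, original_max, short_factor, long_factor):
    # Sequence length = largest position + 1; switch to long factors past the original window
    seq_len = max(position_ids) + 1
    return long_factor if seq_len > original_max else short_factor


def attention_scaling(max_pos, original_max):
    # sqrt(1 + ln(max / original) / ln(original)); multiplied into both cos and sin
    return math.sqrt(1 + math.log(max_pos / original_max) / math.log(original_max))


# Assumed example values: a 4K original window extended to 128K positions
ORIGINAL_MAX, MAX_POS = 4096, 131072

print(select_su_factor(range(4096), ORIGINAL_MAX, "short", "long"))  # short
print(select_su_factor(range(4097), ORIGINAL_MAX, "short", "long"))  # long
print(round(attention_scaling(MAX_POS, ORIGINAL_MAX), 4))  # 1.1902
```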
3. Implementing Su-scaled RoPE
Now that we understand the theory behind Su-scaled RoPE, let's implement it in code. We'll create a SuRoPE class that encapsulates all the functionality we've discussed:
```python
import math

import mlx.core as mx
import mlx.nn as nn


class SuRoPE:
    def __init__(self, config):
        # RoPE is applied per attention head, so the rotated width is the head dimension
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        # Magnitude correction for the extended context window (multiplied into cos and sin)
        self.scaling_factor = math.sqrt(
            1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings)
            / math.log(config.original_max_position_embeddings)
        )
        # Per-frequency rescaling factors (dim // 2 values each)
        self.long_factor = config.rope_scaling["long_factor"]
        self.short_factor = config.rope_scaling["short_factor"]

    def __call__(self, q, k, position_ids=None):
        # q, k: (batch, n_heads, seq_len, head_dim); default positions are 0..seq_len-1
        if position_ids is None:
            position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None]
        cos, sin = self._get_cos_sin(position_ids)
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        position_ids = position_ids.astype(mx.float32)
        # Sequence length is max position + 1; use long factors once it exceeds the original window
        seq_len = mx.max(position_ids).item() + 1
        su_factor = self.long_factor if seq_len > self.original_max_position_embeddings else self.short_factor
        # inv_freq[i] = 1 / (su_factor[i] * theta ** (2i / dim)), shape (dim // 2,)
        inv_freq = 1.0 / (mx.array(su_factor, dtype=mx.float32) * self.rope_theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
        # (batch, dim // 2, 1) @ (batch, 1, seq_len) -> transpose -> (batch, seq_len, dim // 2)
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        freqs = (inv_freq_expanded @ position_ids[:, None, :]).transpose(0, 2, 1)
        # Duplicate so dimensions i and i + dim // 2 share a frequency (matches _rotate_half)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        # Add a head axis for broadcasting: (batch, 1, seq_len, dim)
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    @staticmethod
    def _rotate_half(x):
        # [x1, x2] -> [-x2, x1]: pairs dimension i with dimension i + dim // 2
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)
```
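The SuRoPE class pairs dimension i with dimension i + dim / 2 (the "rotate half" layout) rather than adjacent dimensions. This MLX-free check confirms that x * cos(emb) + rotate_half(x) * sin(emb), with emb = concat(freqs, freqs), is exactly the 2D rotation from Section 1 applied to each pair (x[i], x[i + dim/2]). Some details are assumptions for illustration:

- It leaves out scaling_factor, which multiplies both terms uniformly.
- It uses an assumed base of 10000.
- The helper names are hypothetical.

```python
import math


def rotate_half(x):
    # Same operation as SuRoPE._rotate_half, on a plain list: [x1, x2] -> [-x2, x1]
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]


def apply_rotate_half(x, inv_freq, pos):
    # x * cos(emb) + rotate_half(x) * sin(emb), where emb = concat(freqs, freqs)
    angles = [pos * f for f in inv_freq] * 2
    rh = rotate_half(x)
    return [a * math.cos(t) + b * math.sin(t) for a, b, t in zip(x, rh, angles)]


def apply_pairwise(x, inv_freq, pos):
    # Explicit 2D rotation of each pair (x[i], x[i + h]) by pos * inv_freq[i]
    h = len(x) // 2
    out = [0.0] * len(x)
    for i, f in enumerate(inv_freq):
        t = pos * f
        out[i] = x[i] * math.cos(t) - x[i + h] * math.sin(t)
        out[i + h] = x[i + h] * math.cos(t) + x[i] * math.sin(t)
    return out


dim, theta = 8, 10000.0
inv_freq = [1.0 / theta ** (d / dim) for d in range(0, dim, 2)]
x = [0.3, -1.2, 0.8, 0.5, -0.7, 0.1, 1.1, -0.4]
a = apply_rotate_half(x, inv_freq, pos=5)
b = apply_pairwise(x, inv_freq, pos=5)
print(all(math.isclose(u, v, abs_tol=1e-12) for u, v in zip(a, b)))  # True
```

Because the two layouts differ only by a fixed permutation of dimensions, either works. What matters is that the cos/sin tables and the rotation use the same pairing.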
4. Integrating Su-scaled RoPE into Phi-3-Vision
Integrating our Su-scaled RoPE implementation into the Phi-3-Vision model is straightforward. We only need to add two lines to our Phi3Attention module:
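```python
class Phi3Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ...
        self.rope = SuRoPE(config)

    def __call__(self, x):
        # ...
        # q and k must be shaped (batch, n_heads, seq_len, head_dim) before rotation
        q, k = self.rope(q, k)
        # ...
```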
And now our ported model can handle up to 128K tokens!
Conclusion
In this tutorial, we implemented Su-scaled Rotary Position Embeddings (RoPE), enabling our model to handle sequences up to 128K tokens.
The full implementation is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py
Next, we'll explore efficient batching techniques to further optimize our Phi-3-Vision implementation in MLX.