@@ -86,14 +86,21 @@ def __init__(
8686 theta : float = 10000.0 ,
8787 num_learnable_registers : int = 128 ,
8888 rope_type : str = "interleaved" ,
89+ base_seq_len : int = 4096 ,
90+ double_precision : bool = True ,
8991 attention_kernel : str = "flash" ,
9092 mesh : jax .sharding .Mesh = None ,
9193 rngs : nnx .Rngs = None ,
9294 ):
9395 self .dim = input_dim
96+ self .heads = heads
97+ self .head_dim = head_dim
9498 self .theta = theta
9599 self .num_learnable_registers = num_learnable_registers
96100 self .num_layers = layers
101+ self .rope_type = rope_type
102+ self .base_seq_len = base_seq_len
103+ self .double_precision = double_precision
97104
98105 # 1. Initialize Stacked Layers using vmap
99106 # This creates a single module where parameters have an extra leading dimension [layers, ...]
@@ -165,15 +172,54 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
165172 new_mask = jnp .ones_like (attention_mask )
166173 return output , new_mask
167174
168- def _compute_1d_rope (self , seq_len : int , dtype : DType ) -> Tuple [Array , Array ]:
169- t = jnp .arange (seq_len , dtype = jnp .float32 )
170- freqs = 1.0 / (self .theta ** (jnp .arange (0 , self .dim , 2 , dtype = jnp .float32 ) / self .dim ))
171- emb = jnp .outer (t , freqs )
172- cos = jnp .cos (emb )
173- sin = jnp .sin (emb )
174- cos = jnp .repeat (cos , 2 , axis = - 1 )
175- sin = jnp .repeat (sin , 2 , axis = - 1 )
176- return cos [None , ...], sin [None , ...]
175+ def _compute_1d_rope (self , batch_size : int , seq_len : int , dtype : DType ) -> Tuple [Array , Array ]:
176+ grid_1d = jnp .arange (seq_len , dtype = jnp .float32 )
177+ grid_1d = grid_1d / self .base_seq_len
178+ grid = jnp .expand_dims (grid_1d , 0 )
179+ grid = jnp .tile (grid , (batch_size , 1 ))
180+
181+ num_rope_elems = 2
182+ freqs_dtype = jnp .float64 if self .double_precision else jnp .float32
183+ steps = self .dim // num_rope_elems
184+ pow_indices = jnp .power (self .theta , jnp .linspace (0.0 , 1.0 , steps , dtype = freqs_dtype ))
185+ base_freqs = (pow_indices * jnp .pi / 2.0 ).astype (jnp .float32 )
186+
187+ freqs = (jnp .expand_dims (grid , - 1 ) * 2.0 - 1.0 ) * base_freqs
188+
189+ cos_freqs = jnp .cos (freqs )
190+ sin_freqs = jnp .sin (freqs )
191+
192+ if self .rope_type == "interleaved" :
193+ cos_freqs = jnp .repeat (cos_freqs , 2 , axis = - 1 )
194+ sin_freqs = jnp .repeat (sin_freqs , 2 , axis = - 1 )
195+
196+ if self .dim % num_rope_elems != 0 :
197+ curr_dim = cos_freqs .shape [- 1 ]
198+ pad_amt = self .dim - curr_dim
199+ if pad_amt > 0 :
200+ cos_padding = jnp .ones ((* cos_freqs .shape [:- 1 ], pad_amt ), dtype = cos_freqs .dtype )
201+ sin_padding = jnp .zeros ((* sin_freqs .shape [:- 1 ], pad_amt ), dtype = sin_freqs .dtype )
202+ cos_freqs = jnp .concatenate ([cos_padding , cos_freqs ], axis = - 1 )
203+ sin_freqs = jnp .concatenate ([sin_padding , sin_freqs ], axis = - 1 )
204+
205+ elif self .rope_type == "split" :
206+ expected_freqs = self .dim // 2
207+ current_freqs = freqs .shape [- 1 ]
208+ pad_size = expected_freqs - current_freqs
209+
210+ if pad_size > 0 :
211+ cos_padding = jnp .ones ((* cos_freqs .shape [:- 1 ], pad_size ), dtype = cos_freqs .dtype )
212+ sin_padding = jnp .zeros ((* sin_freqs .shape [:- 1 ], pad_size ), dtype = sin_freqs .dtype )
213+ cos_freqs = jnp .concatenate ([cos_padding , cos_freqs ], axis = - 1 )
214+ sin_freqs = jnp .concatenate ([sin_padding , sin_freqs ], axis = - 1 )
215+
216+ b = cos_freqs .shape [0 ]
217+ t = cos_freqs .shape [1 ]
218+ h = self .heads
219+ cos_freqs = cos_freqs .reshape (b , t , h , - 1 ).transpose (0 , 2 , 1 , 3 )
220+ sin_freqs = sin_freqs .reshape (b , t , h , - 1 ).transpose (0 , 2 , 1 , 3 )
221+
222+ return cos_freqs , sin_freqs
177223
178224 def __call__ (
179225 self ,
@@ -198,8 +244,9 @@ def __call__(
198244 mean = jnp .mean (hidden_states ), std = jnp .std (hidden_states ))
199245
200246 # 2. RoPE
247+ batch_size = hidden_states .shape [0 ]
201248 seq_len = hidden_states .shape [1 ]
202- rotary_emb = self ._compute_1d_rope (seq_len , hidden_states .dtype )
249+ rotary_emb = self ._compute_1d_rope (batch_size , seq_len , hidden_states .dtype )
203250
204251 # 3. Transformer Blocks (Scan)
205252
0 commit comments