67 | 67 | ) |
68 | 68 | from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init |
69 | 69 | from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes |
70 | | -from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm |
| 70 | +from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm, GlobalRMSNorm |
71 | 71 | from MaxText.layers.quantizations import AqtQuantization as Quant |
72 | 72 | from maxtext.inference import kvcache, page_manager, paged_attention |
73 | 73 | from maxtext.inference.kvcache import KVQuant |
@@ -164,6 +164,7 @@ def attention_as_linen( |
164 | 164 | use_mrope: bool = False, |
165 | 165 | mrope_section: tuple[int, int, int] | None = None, |
166 | 166 | name: str | None = None, |
| 167 | + rope_type: str | None = None, |
167 | 168 | ): |
168 | 169 | """A factory function to create an Attention as a Linen module. |
169 | 170 |
|
@@ -228,6 +229,7 @@ def attention_as_linen( |
228 | 229 | use_mrope=use_mrope, |
229 | 230 | mrope_section=mrope_section, |
230 | 231 | name=name, |
| 232 | + rope_type=rope_type, |
231 | 233 | metadata_fn=variable_to_logically_partitioned, |
232 | 234 | abstract_init=False, |
233 | 235 | ) |
@@ -328,6 +330,7 @@ def __init__( |
328 | 330 | use_mrope: bool = False, |
329 | 331 | mrope_section: tuple[int, int, int] | None = None, |
330 | 332 | name: str | None = None, |
| 333 | + rope_type: str | None = None, |
331 | 334 | rngs: Optional[nnx.Rngs] = None, |
332 | 335 | ): |
333 | 336 | """Initializes the Attention module. |
@@ -367,6 +370,8 @@ def __init__( |
367 | 370 | is_vision: Whether this is a vision attention layer. |
368 | 371 | model_mode: The model's operational mode (e.g., 'train', 'prefill'). |
369 | 372 | base_kv_cache: Whether to use base (non-MLA) kv cache, if KVCache is used |
| 373 | + rope_type: Optional override for the RoPE type (e.g., 'default', 'yarn'). |
| 374 | + If provided, this takes precedence over `config.rope_type`. |
370 | 375 | rngs: RNG state for initialization, passed by the nnx.to_linen wrapper. |
371 | 376 | """ |
372 | 377 |
|
@@ -424,6 +429,8 @@ def __init__( |
424 | 429 | self.use_mrope = use_mrope |
425 | 430 | self.mrope_section = mrope_section |
426 | 431 | self.rngs = rngs |
| 432 | + # Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config. |
| 433 | + self.rope_type = (rope_type or self.config.rope_type).lower() |
427 | 434 |
|
428 | 435 | self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT |
429 | 436 |
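
A minimal sketch of the override rule added above (pure illustration; `resolve_rope_type` is not a MaxText function): the explicit per-layer argument wins, and anything falsy, including `None` or an empty string, falls back to the config value before lowercasing.

```python
def resolve_rope_type(rope_type, config_rope_type):
    # Explicit per-layer argument takes precedence; None (or "") falls back
    # to the model-wide config value. Lowercasing once here keeps downstream
    # comparisons (e.g., in init_rotary_embedding) case-insensitive.
    return (rope_type or config_rope_type).lower()

assert resolve_rope_type(None, "YaRN") == "yarn"
assert resolve_rope_type("default", "yarn") == "default"
```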
|
@@ -490,18 +497,28 @@ def __init__( |
490 | 497 | self.sinks = None |
491 | 498 |
|
492 | 499 | is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4 |
| 500 | + |
493 | 501 | if self.use_qk_norm and not is_llama4_decoder_block: |
494 | | - self.query_norm = RMSNorm( |
495 | | - num_features=self.head_dim, |
| 502 | + # Check if this is Olmo3, which uses a unique "Global" QK Norm strategy. |
| 503 | + # GlobalRMSNorm flattens (Heads, Dim) to normalize across the entire hidden state. |
| 504 | + use_global_qk_norm = self.config.model_name.startswith("olmo3") |
| 505 | + qk_norm_cls = GlobalRMSNorm if use_global_qk_norm else RMSNorm |
| 506 | + |
| 507 | + # RMSNorm normalizes each head independently over `head_dim`; GlobalRMSNorm normalizes over the flattened `num_heads * head_dim` hidden state.
| 508 | + q_features = (self.num_query_heads * self.head_dim) if use_global_qk_norm else self.head_dim |
| 509 | + k_features = (self.num_kv_heads * self.head_dim) if use_global_qk_norm else self.head_dim |
| 510 | + |
| 511 | + self.query_norm = qk_norm_cls( |
| 512 | + num_features=q_features, |
496 | 513 | dtype=self.config.dtype, |
497 | 514 | weight_dtype=self.config.weight_dtype, |
498 | 515 | shard_mode=self.config.shard_mode, |
499 | 516 | epsilon=self.config.normalization_layer_epsilon, |
500 | 517 | kernel_axes=("norm",), |
501 | 518 | rngs=self.rngs, |
502 | 519 | ) |
503 | | - self.key_norm = RMSNorm( |
504 | | - num_features=self.head_dim, |
| 520 | + self.key_norm = qk_norm_cls( |
| 521 | + num_features=k_features, |
505 | 522 | dtype=self.config.dtype, |
506 | 523 | weight_dtype=self.config.weight_dtype, |
507 | 524 | shard_mode=self.config.shard_mode, |
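
For intuition, here is a minimal numerical sketch of the two normalization modes selected above (assumed semantics based on the comment in this diff, not the actual `RMSNorm`/`GlobalRMSNorm` classes; learned scale parameters are omitted):

```python
import jax.numpy as jnp

def rms_norm_per_head(x, eps=1e-6):
    # x: [batch, seq, heads, head_dim]; statistics over head_dim only,
    # so each head is normalized independently (scale has head_dim features).
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps)

def rms_norm_global(x, eps=1e-6):
    # Flatten (heads, head_dim), take statistics over the whole hidden state
    # (scale has heads * head_dim features), then restore the original shape.
    b, s, h, d = x.shape
    flat = x.reshape(b, s, h * d)
    ms = jnp.mean(jnp.square(flat), axis=-1, keepdims=True)
    return (flat / jnp.sqrt(ms + eps)).reshape(b, s, h, d)
```

This matches the feature counts chosen above: `head_dim` for the per-head case, and `num_query_heads * head_dim` / `num_kv_heads * head_dim` for the global case.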
@@ -726,7 +743,7 @@ def init_rotary_embedding(self): |
726 | 743 | else: |
727 | 744 | rope_embedding_dims = self.head_dim |
728 | 745 |
|
729 | | - rope_type = self.config.rope_type.lower() |
| 746 | + rope_type = self.rope_type |
730 | 747 | rope_use_scale = self.config.rope_use_scale |
731 | 748 | if self.is_vision: |
732 | 749 | if self.config.model_name.startswith("qwen3-omni"): |
|