Commit 3f928e4 (parent: 1a84208)

Olmo3 checkpoint conversion and refactor of the Olmo3 model to support interleaved RoPE

13 files changed: 320 additions & 24 deletions

src/MaxText/layers/attentions.py

Lines changed: 23 additions & 6 deletions
@@ -67,7 +67,7 @@
 )
 from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
 from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
-from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm
+from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm, GlobalRMSNorm
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from maxtext.inference import kvcache, page_manager, paged_attention
 from maxtext.inference.kvcache import KVQuant
@@ -164,6 +164,7 @@ def attention_as_linen(
     use_mrope: bool = False,
     mrope_section: tuple[int, int, int] | None = None,
     name: str | None = None,
+    rope_type: str | None = None,
 ):
   """A factory function to create an Attention as a Linen module.

@@ -228,6 +229,7 @@ def attention_as_linen(
       use_mrope=use_mrope,
       mrope_section=mrope_section,
       name=name,
+      rope_type=rope_type,
       metadata_fn=variable_to_logically_partitioned,
       abstract_init=False,
   )
@@ -328,6 +330,7 @@ def __init__(
       use_mrope: bool = False,
       mrope_section: tuple[int, int, int] | None = None,
       name: str | None = None,
+      rope_type: str | None = None,
       rngs: Optional[nnx.Rngs] = None,
   ):
     """Initializes the Attention module.
@@ -367,6 +370,8 @@ def __init__(
       is_vision: Whether this is a vision attention layer.
       model_mode: The model's operational mode (e.g., 'train', 'prefill').
       base_kv_cache: Whether to use base (non-MLA) kv cache, if KVCache is used
+      rope_type: Optional override for the RoPE type (e.g., 'default', 'yarn').
+        If provided, this takes precedence over `config.rope_type`.
       rngs: RNG state for initialization, passed by the nnx.to_linen wrapper.
     """

@@ -424,6 +429,8 @@ def __init__(
     self.use_mrope = use_mrope
     self.mrope_section = mrope_section
     self.rngs = rngs
+    # Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config.
+    self.rope_type = (rope_type or self.config.rope_type).lower()

     self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT

@@ -490,18 +497,28 @@ def __init__(
     self.sinks = None

     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
+
     if self.use_qk_norm and not is_llama4_decoder_block:
-      self.query_norm = RMSNorm(
-          num_features=self.head_dim,
+      # Check if this is Olmo3, which uses a unique "Global" QK Norm strategy.
+      # GlobalRMSNorm flattens (Heads, Dim) to normalize across the entire hidden state.
+      use_global_qk_norm = self.config.model_name.startswith("olmo3")
+      qk_norm_cls = GlobalRMSNorm if use_global_qk_norm else RMSNorm
+
+      # For RMSNorm use `head_dim` (per-head normalization); for GlobalRMSNorm use `num_heads * head_dim` (global normalization).
+      q_features = (self.num_query_heads * self.head_dim) if use_global_qk_norm else self.head_dim
+      k_features = (self.num_kv_heads * self.head_dim) if use_global_qk_norm else self.head_dim
+
+      self.query_norm = qk_norm_cls(
+          num_features=q_features,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           shard_mode=self.config.shard_mode,
           epsilon=self.config.normalization_layer_epsilon,
           kernel_axes=("norm",),
           rngs=self.rngs,
       )
-      self.key_norm = RMSNorm(
-          num_features=self.head_dim,
+      self.key_norm = qk_norm_cls(
+          num_features=k_features,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           shard_mode=self.config.shard_mode,
@@ -726,7 +743,7 @@ def init_rotary_embedding(self):
     else:
       rope_embedding_dims = self.head_dim

-    rope_type = self.config.rope_type.lower()
+    rope_type = self.rope_type
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
       if self.config.model_name.startswith("qwen3-omni"):
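For clarity, here is a tiny standalone sketch of the precedence the new rope_type argument introduces (the helper name and the assert examples are hypothetical; only the expression itself mirrors the diff):

def resolve_rope_type(rope_type_arg, config_rope_type):
  # Per-layer override wins; otherwise fall back to config.rope_type. Lower-cased either way.
  return (rope_type_arg or config_rope_type).lower()

assert resolve_rope_type(None, "YaRN") == "yarn"  # no override: the config value is used
assert resolve_rope_type("default", "yarn") == "default"  # Olmo3 sliding-window layers pass "default"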

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 0 deletions
@@ -902,6 +902,8 @@ def __call__(
       layer_kwargs = {"layer_idx": lyr}
       if cfg.decoder_block == DecoderBlockType.GPT_OSS:
         layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
+      if cfg.decoder_block == DecoderBlockType.OLMO3:
+        layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
       layer = RemattedBlockLayer(
           config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
       )
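olmo3.get_attention_type itself is not part of this diff. Purely as a hypothetical sketch (the interleaving period is an assumption, not taken from the source), such a helper maps a layer index to an attention type:

# Hypothetical sketch only: the real implementation lives in src/MaxText/layers/olmo3.py and may differ.
# It returns attentions.AttentionType values; plain strings are used here to stay self-contained.
def get_attention_type(layer_id: int, period: int = 4):
  """Assume full (global) attention on every `period`-th layer, sliding-window otherwise."""
  return "global" if (layer_id + 1) % period == 0 else "local_sliding"

# e.g. layers 0..7 -> sliding, sliding, sliding, global, sliding, sliding, sliding, global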

src/MaxText/layers/normalizations.py

Lines changed: 22 additions & 0 deletions
@@ -80,6 +80,28 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
     return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding)


+class GlobalRMSNorm(RMSNorm):
+  """
+  Applies RMSNorm over the last two dimensions (Heads * HeadDim).
+  Used for Olmo3 which normalizes across all heads combined.
+  """
+
+  def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> jnp.ndarray:
+    # x shape: [..., Heads, HeadDim]
+    input_shape = x.shape
+
+    # Flatten the last two dimensions: [..., Heads * HeadDim]
+    # We use -2 and -1 to ensure we capture the last two dims regardless of rank
+    flattened_shape = input_shape[:-2] + (input_shape[-2] * input_shape[-1],)
+    x_flat = x.reshape(flattened_shape)
+
+    # Apply standard RMSNorm (which normalizes over the last axis)
+    y_flat = super().__call__(x_flat, out_sharding)
+
+    # Reshape back to [..., Heads, HeadDim]
+    return y_flat.reshape(input_shape)
+
+
 def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
   """
   Used for input and post attention layernorms
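To make the behavioural difference concrete, here is a minimal sketch in plain JAX (no learned scale, made-up shapes) of per-head RMS normalization versus the global variant GlobalRMSNorm implements:

import jax.numpy as jnp

def rms_normalize(x, eps=1e-6):
  # RMS-normalize over the last axis only (what the base RMSNorm does, minus the learned scale).
  return x * (jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) ** -0.5

x = jnp.ones((2, 16, 8, 64))  # [batch, seq, heads, head_dim]; shapes are illustrative

per_head = rms_normalize(x)  # each head normalized independently over head_dim
global_norm = rms_normalize(x.reshape(2, 16, 8 * 64)).reshape(x.shape)  # all heads normalized together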

src/MaxText/layers/olmo3.py

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,10 @@ def __init__(
         rngs=rngs,
     )

+    current_rope_type = config.rope_type.lower()
+    if self.attention_type == attentions.AttentionType.LOCAL_SLIDING:
+      current_rope_type = "default"
+
     # Self-attention block
     self.attention = Attention(
         config=config,
@@ -121,6 +125,7 @@
         query_pre_attn_scalar=(config.head_dim**-0.5),
         model_mode=model_mode,
         use_qk_norm=config.use_qk_norm,
+        rope_type=current_rope_type,
         rngs=rngs,
     )
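The net effect is interleaved RoPE: sliding-window layers keep plain rotary embeddings, while full-attention layers use whatever config.rope_type specifies. A small illustration, assuming config.rope_type = "yarn" and the hypothetical four-layer interleaving sketched earlier:

# layer_id:        0          1          2          3         4          ...
# attention_type:  sliding    sliding    sliding    global    sliding    ...
# rope_type:       "default"  "default"  "default"  "yarn"    "default"  ...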

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 13 additions & 8 deletions
@@ -111,9 +111,10 @@ class LazyHFLoader:
   can still occur in parallel.
   """

-  def __init__(self, model_id, token):
+  def __init__(self, model_id, token, revision=None):
     self.model_id = model_id
     self.token = token
+    self.revision = revision
     # Whether loads from local directory
     self.is_local = os.path.isdir(self.model_id)
     self.shard_map = {}
@@ -156,7 +157,7 @@ def _initialize_index(self):
     if self.is_local:
       index_path = os.path.join(self.model_id, index_file)
     else:
-      index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token)
+      index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token, revision=self.revision)
     with open(index_path, "r", encoding="utf-8") as f:
       index_data = json.load(f)
     self.shard_map = index_data["weight_map"]
@@ -186,7 +187,7 @@ def get_tensor(self, key: str) -> np.ndarray:
     else:
       # STEP 1: Download outside the lock.
       # multiple threads can download different shards at the same time.
-      local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token)
+      local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token, revision=self.revision)

      # STEP 2: Lock ONLY the reading into RAM.
      # This prevents multiple threads from simultaneously allocating large chunks of RAM.
@@ -574,7 +575,7 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
   output_directory = config.base_output_directory

   hf_token = config.hf_access_token
-
+  revision = test_args.revision
   use_lazy_load = test_args.lazy_load_tensors

   if use_lazy_load and config.use_multimodal:
@@ -586,14 +587,14 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
   # Define the appropriate tensor getter based on mode
   if use_lazy_load:
     max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...")
-    hf_loader = LazyHFLoader(model_id, hf_token)
-    hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token)
+    hf_loader = LazyHFLoader(model_id, hf_token, revision=revision)
+    hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision)
     print_ram_usage("After LazyLoader init")
     tensor_getter = hf_loader.get_tensor
   else:
     max_logging.log(f"Lazy loading DISABLED. Loading full HuggingFace model: {model_id}...")
-    hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token)
-    hf_model = get_hf_model(model_id, token=hf_token)
+    hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision)
+    hf_model = get_hf_model(model_id, token=hf_token, revision=revision)
     hf_state_dict_numpy = hf_model.state_dict()
     # Convert all to numpy immediately in eager mode
     for k, v in hf_state_dict_numpy.items():
@@ -729,6 +730,10 @@ def _eager_getter(key):
   # storage: chunk_shape=(151936, 1024) <-- Full layer in one chunk
   parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)

+  parser.add_argument(
+      "--revision", type=str, required=False, default=None, help="Specific Hugging Face revision (branch/tag/commit)"
+  )
+
   # Parse local arguments
   # Parse known args returns the namespace AND the list of remaining arguments
   local_args, remaining_args = parser.parse_known_args()
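The pinned revision flows from the new --revision flag into LazyHFLoader, AutoConfig.from_pretrained, and get_hf_model. A hypothetical usage sketch of the lazy loader (the repo id, token variable, and weight key below are illustrative examples, not taken from this diff):

loader = LazyHFLoader("allenai/Olmo-3-7B", token=hf_token, revision="main")  # repo id and revision are examples
embedding = loader.get_tensor("model.embed_tokens.weight")  # a typical HF weight key, not from this diff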
