Commit 34e32a1
Onboard Olmo3 config and decoder layer
1 parent 96f1375

7 files changed
Lines changed: 422 additions & 0 deletions
src/MaxText/common_types.py
Lines changed: 1 addition & 0 deletions

@@ -100,6 +100,7 @@ class DecoderBlockType(enum.Enum):
   SIMPLE = "simple"
   SIMPLE_MLP = "simple_mlp"
   LLAMA4 = "llama4"
+  OLMO3 = "olmo3"


 class AttentionType(enum.Enum):
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# AllenAI OLMo 3 32B Configuration
+# https://huggingface.co/allenai/Olmo-3.1-32B-Instruct/blob/main/config.json
+
+model_name: "olmo3_32b"
+decoder_block: "olmo3"
+
+# Model Dimensions
+base_emb_dim: 5120
+base_num_query_heads: 40
+base_num_kv_heads: 8
+base_mlp_dim: 27648
+base_num_decoder_layers: 64
+head_dim: 128
+
+# Activations & Normalization
+mlp_activations: ["silu", "linear"]
+normalization_layer_epsilon: 1.e-6
+use_qk_norm: True
+
+# Attention
+# Layers 0, 1, 2 use sliding window 4096; layer 3 uses global attention. This 4-layer pattern repeats.
+sliding_window_size: 4096
+inhomogeneous_layer_cycle_interval: 4
+
+# RoPE (YaRN)
+rope_type: "yarn"
+rope_max_timescale: 500000  # rope_theta
+rope_factor: 8.0  # factor so 0.1 * ln(rope_factor) + 1.0 = 1.2079441541679836
+original_max_position_embeddings: 8192
+beta_fast: 32.0
+beta_slow: 1.0
+max_position_embeddings: 65536
+rope_attention_scaling: True
+
+# Embeddings
+vocab_size: 100278
+logits_via_embedding: False
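
The attention section above is worth unpacking: sliding-window and global layers interleave on a fixed cycle, with the cycle length set by `inhomogeneous_layer_cycle_interval`. Below is a minimal sketch of how a layer index could map to its attention type under these two config values; the helper is hypothetical and not part of this commit, which only adds the config keys.

# Hypothetical helper, not from this commit: derives a layer's attention
# type from the two config values above. Per the config comment, the global
# layer sits at the last position of each 4-layer cycle.
CYCLE_INTERVAL = 4  # inhomogeneous_layer_cycle_interval
WINDOW = 4096       # sliding_window_size

def attention_type_for_layer(layer_idx: int) -> str:
  if layer_idx % CYCLE_INTERVAL == CYCLE_INTERVAL - 1:
    return "global"
  return f"sliding(window={WINDOW})"

# Layers 0-7 -> sliding x3, global, sliding x3, global
print([attention_type_for_layer(i) for i in range(8)])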
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# AllenAI OLMo 3 7B Configuration
+# https://huggingface.co/allenai/Olmo-3-7B-Instruct
+
+model_name: "olmo3_7b"
+decoder_block: "olmo3"
+
+# Model Dimensions
+base_emb_dim: 4096
+base_num_query_heads: 32
+base_num_kv_heads: 32
+base_mlp_dim: 11008
+base_num_decoder_layers: 32
+head_dim: 128
+
+# Activations & Normalization
+mlp_activations: ["silu", "linear"]  # SwiGLU
+normalization_layer_epsilon: 1.e-6
+use_qk_norm: True
+
+# Attention
+# Layers 0, 1, 2 use sliding window 4096; layer 3 uses global attention. This 4-layer pattern repeats.
+sliding_window_size: 4096
+inhomogeneous_layer_cycle_interval: 4
+
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 500000  # rope_theta
+rope_factor: 8.0  # factor so 0.1 * ln(rope_factor) + 1.0 = 1.2079441541679836
+original_max_position_embeddings: 8192
+beta_fast: 32.0
+beta_slow: 1.0
+max_position_embeddings: 65536
+rope_attention_scaling: True
+
+# Embeddings
+vocab_size: 100278
+logits_via_embedding: False
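
Both configs carry the same `rope_factor` comment, which is the standard YaRN attention-scaling ("mscale") formula written out with its value. A quick arithmetic check in Python (the function name is illustrative, not MaxText API):

import math

def yarn_attention_scale(rope_factor: float) -> float:
  # Standard YaRN attention scaling: 0.1 * ln(rope_factor) + 1.0
  return 0.1 * math.log(rope_factor) + 1.0

print(yarn_attention_scale(8.0))  # 1.2079441541679836, matching the config comment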

src/MaxText/configs/types.py
Lines changed: 2 additions & 0 deletions

@@ -236,6 +236,8 @@ class ProfilerType(str, Enum):
     "gpt-oss-120b",
     "llama4-17b-16e",
     "llama4-17b-128e",
+    "olmo3_7b",
+    "olmo3_32b",
 ]

src/MaxText/layers/decoders.py
Lines changed: 5 additions & 0 deletions

@@ -59,6 +59,7 @@
     mixtral,
     qwen3,
     simple_layer,
+    olmo3,
 )

 # ------------------------------------------------------------------------------

@@ -430,6 +431,9 @@ def get_decoder_layers(self):
         return [simple_layer.SimpleMlpDecoderLayerToLinen]
       case DecoderBlockType.LLAMA4:
         return [llama4.Llama4ScannableBlockToLinen] if self.config.scan_layers else [llama4.Llama4DecoderLayerToLinen]
+      case DecoderBlockType.OLMO3:
+        return [olmo3.Olmo3ScannableBlockToLinen] if self.config.scan_layers else [olmo3.Olmo3DecoderLayerToLinen]
+
       case _:
         # Default case to handle any unknown decoder block types.
         raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

@@ -479,6 +483,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,
         DecoderBlockType.LLAMA4,
+        DecoderBlockType.OLMO3,
     ):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
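
Tying the pieces together: the `decoder_block: "olmo3"` value in the YAML configs resolves to the `DecoderBlockType.OLMO3` member added in `common_types.py`, and the `match` in `get_decoder_layers` then picks the scannable block or the per-layer class. A self-contained sketch of that round trip (the enum is abridged here for illustration; the real one lives in `src/MaxText/common_types.py`):

import enum

class DecoderBlockType(enum.Enum):  # abridged copy for illustration
  LLAMA4 = "llama4"
  OLMO3 = "olmo3"

# The YAML string "olmo3" resolves to the new member by value...
assert DecoderBlockType("olmo3") is DecoderBlockType.OLMO3

# ...and get_decoder_layers() then chooses between the scannable block
# (when config.scan_layers is set) and the per-layer class.
scan_layers = True
cls_name = "Olmo3ScannableBlockToLinen" if scan_layers else "Olmo3DecoderLayerToLinen"
print(cls_name)  # Olmo3ScannableBlockToLinen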
