Skip to content

Commit 93e2feb

Browse files
Merge pull request #3113 from AI-Hypercomputer:jimmytsai/bringup-qwen2-5
PiperOrigin-RevId: 884762695
2 parents ab84f8e + d6c9842 commit 93e2feb

24 files changed

Lines changed: 851 additions & 262 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
107107
* Gemma 2 (2B, 9B, 27B)
108108
* Gemma 1 (2B, 7B)
109109
* Alibaba
110+
* Qwen 2.5 (7B, 14B)
110111
* Qwen 3 MoE 2507 (235B, 480B)
111112
* Qwen 3 MoE (30B, 235B)
112113
* Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The following models are supported:
1111
| **Gemma2** | 2B, 9B, 27B |||||
1212
| **Gemma3** (Multimodal) | 4B, 12B, 27B |||||
1313
| **Llama3.1** | 8B, 70B, 405B |||||
14+
| **Qwen2.5** | 7B, 14B |||||
1415
| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B |||||
1516
| **Qwen3 MoE** | 30B, 235B, 480B |||||
1617
| **Mixtral** | 8x7B, 8x22B |||||

src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -210,26 +210,45 @@
210210
query_pre_attn_scalar=144,
211211
)
212212

213-
qwen3_0_6b_config = transformers.Qwen3Config(
214-
vocab_size=151936,
215-
hidden_size=1024,
216-
intermediate_size=3072,
213+
qwen25_7b_config = transformers.Qwen2Config(
214+
vocab_size=152064,
215+
hidden_size=3584,
216+
intermediate_size=18944,
217217
num_hidden_layers=28,
218-
num_attention_heads=16,
218+
num_attention_heads=28,
219+
num_key_value_heads=4,
220+
hidden_act="silu",
221+
max_position_embeddings=32768,
222+
initializer_range=0.02,
223+
rms_norm_eps=1e-06,
224+
use_cache=True,
225+
rope_theta=1000000.0,
226+
tie_word_embeddings=False,
227+
torch_dtype="bfloat16",
228+
attention_bias=True,
229+
)
230+
231+
qwen25_14b_config = transformers.Qwen2Config(
232+
vocab_size=152064,
233+
hidden_size=5120,
234+
intermediate_size=13824,
235+
num_hidden_layers=48,
236+
num_attention_heads=40,
219237
num_key_value_heads=8,
220-
head_dim=128,
221238
hidden_act="silu",
222-
max_position_embeddings=40960,
223-
rms_norm_eps=1.0e-6,
239+
max_position_embeddings=32768,
240+
rms_norm_eps=1e-06,
224241
rope_theta=1000000.0,
225-
tie_word_embeddings=True,
242+
tie_word_embeddings=False,
226243
torch_dtype="bfloat16",
244+
attention_bias=True,
227245
)
228246

229-
qwen3_1_7b_config = transformers.Qwen3Config(
247+
248+
qwen3_0_6b_config = transformers.Qwen3Config(
230249
vocab_size=151936,
231-
hidden_size=2048,
232-
intermediate_size=6144,
250+
hidden_size=1024,
251+
intermediate_size=3072,
233252
num_hidden_layers=28,
234253
num_attention_heads=16,
235254
num_key_value_heads=8,
@@ -831,23 +850,19 @@
831850
"gemma3-4b": gemma3_4b_config,
832851
"gemma3-12b": gemma3_12b_config,
833852
"gemma3-27b": gemma3_27b_config,
853+
"qwen2.5-7b": qwen25_7b_config,
854+
"qwen2.5-14b": qwen25_14b_config,
834855
"qwen3-0.6b": qwen3_0_6b_config,
835-
"qwen3-1.7b": qwen3_1_7b_config,
836-
"qwen3-1.7b-base": qwen3_1_7b_config,
837856
"qwen3-4b": qwen3_4b_config,
838-
"qwen3-4b-base": qwen3_4b_config,
839857
"qwen3-4b-thinking-2507": qwen3_4b_config,
840858
"qwen3-8b": qwen3_8b_config,
841-
"qwen3-8b-base": qwen3_8b_config,
842859
"qwen3-14b": qwen3_14b_config,
843-
"qwen3-14b-base": qwen3_14b_config,
844860
"qwen3-32b": qwen3_32b_config,
845861
"llama3.1-8b": llama31_8b_config,
846862
"llama3.1-8b-Instruct": llama31_8b_config,
847863
"llama3.1-70b": llama31_70b_config,
848864
"llama3.1-405b": llama31_405b_config,
849865
"qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
850-
"qwen3-30b-a3b-base": qwen3_30b_a3b_thinking_2507_config,
851866
"qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
852867
"qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
853868
"deepseek3-671b": deepseek3_671b_config,

src/maxtext/checkpoint_conversion/utils/hf_shape.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
529529
return mapping
530530

531531

532-
def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
533-
"""Returns mapping between HuggingFace Qwen3 weights path and the HuggingFace weights shape.
532+
def QWEN_HF_WEIGHTS_TO_SHAPE(config):
533+
"""Returns mapping between HuggingFace Qwen weights path and the HuggingFace weights shape.
534534
535535
To check this mapping, dump the huggingface model shapes:
536536
from transformers import AutoModelForCausalLM
@@ -555,6 +555,7 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
555555
head_dim = config.get(
556556
"head_dim", config["hidden_size"] // config["num_attention_heads"]
557557
) # head_dim might not always be present
558+
attention_bias = config.get("attention_bias", False)
558559

559560
mapping = {
560561
"model.embed_tokens.weight": [config["vocab_size"], hidden_size],
@@ -580,6 +581,15 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
580581
f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
581582
}
582583

584+
if attention_bias:
585+
layer_mapping.update(
586+
{
587+
f"{layer_prefix}.self_attn.q_proj.bias": [num_attention_heads * head_dim],
588+
f"{layer_prefix}.self_attn.k_proj.bias": [num_key_value_heads * head_dim],
589+
f"{layer_prefix}.self_attn.v_proj.bias": [num_key_value_heads * head_dim],
590+
}
591+
)
592+
583593
if num_experts > 1:
584594
# MoE MLP layers
585595
moe_ffn_intermediate_size = config.get("moe_intermediate_size")
@@ -756,18 +766,20 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
756766
"gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
757767
"gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
758768
"gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
759-
"qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE,
760-
"qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE,
761-
"qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE,
762-
"qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE,
763-
"qwen3-14b": QWEN3_HF_WEIGHTS_TO_SHAPE,
764-
"qwen3-32b": QWEN3_HF_WEIGHTS_TO_SHAPE,
769+
"qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE,
770+
"qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
771+
"qwen3-0.6b": QWEN_HF_WEIGHTS_TO_SHAPE,
772+
"qwen3-4b": QWEN_HF_WEIGHTS_TO_SHAPE,
773+
"qwen3-4b-thinking-2507": QWEN_HF_WEIGHTS_TO_SHAPE,
774+
"qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE,
775+
"qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
776+
"qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE,
765777
"llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
766778
"llama3.1-70b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
767779
"llama3.1-405b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
768-
"qwen3-30b-a3b": QWEN3_HF_WEIGHTS_TO_SHAPE,
769-
"qwen3-235b-a22b": QWEN3_HF_WEIGHTS_TO_SHAPE,
770-
"qwen3-480b-a35b": QWEN3_HF_WEIGHTS_TO_SHAPE,
780+
"qwen3-30b-a3b": QWEN_HF_WEIGHTS_TO_SHAPE,
781+
"qwen3-235b-a22b": QWEN_HF_WEIGHTS_TO_SHAPE,
782+
"qwen3-480b-a35b": QWEN_HF_WEIGHTS_TO_SHAPE,
771783
"deepseek3-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE,
772784
"gpt-oss-20b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,
773785
"gpt-oss-120b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,

src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
587587
return mapping
588588

589589

590-
def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
591-
"""Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
590+
def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
591+
"""Returns mapping from MaxText to HuggingFace Qwen weight paths.
592592
593593
This function generates a dictionary that maps parameter names from a MaxText
594-
Qwen3 checkpoint to their corresponding names in the Hugging Face format.
594+
Qwen checkpoint to their corresponding names in the Hugging Face format.
595595
It handles both dense and Mixture-of-Experts (MoE) model variants.
596596
597597
Args:
@@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
631631
"params-decoder-layers-self_attention-value-kernel": [
632632
f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
633633
],
634+
"params-decoder-layers-self_attention-query-bias": [
635+
f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
636+
],
637+
"params-decoder-layers-self_attention-key-bias": [
638+
f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
639+
],
640+
"params-decoder-layers-self_attention-value-bias": [
641+
f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
642+
],
634643
"params-decoder-layers-self_attention-out-kernel": [
635644
f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
636645
],
@@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
688697
f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
689698
f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
690699
f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
700+
f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
701+
f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
702+
f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
691703
f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
692704
f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
693705
f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
@@ -721,8 +733,8 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
721733
return mapping
722734

723735

724-
def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
725-
"""Creates parameter transformation functions for Qwen3.
736+
def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
737+
"""Creates parameter transformation functions for Qwen.
726738
727739
This function provides a dictionary of transformation functions (hooks) for
728740
converting Qwen3 model parameters between MaxText and Hugging Face formats.
@@ -766,6 +778,15 @@ def reshape_kernel(input_tensor, target_shape):
766778
else:
767779
return input_tensor.T.reshape(target_shape)
768780

781+
def reshape_bias(input_tensor, target_shape=None):
782+
"""Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
783+
if saving_to_hf:
784+
# MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
785+
return input_tensor.reshape(target_shape)
786+
else:
787+
# HF [hidden_dim] -> MaxText [heads, head_dim]
788+
return input_tensor.reshape(target_shape)
789+
769790
mapping = {
770791
"params-token_embedder-embedding": pad_embedding_layer,
771792
"params-decoder-logits_dense-kernel": reshape_kernel,
@@ -780,6 +801,11 @@ def reshape_kernel(input_tensor, target_shape):
780801
"mlp-wi_1-kernel",
781802
"mlp-wo-kernel",
782803
]
804+
bias_hooks = [
805+
"self_attention-query-bias",
806+
"self_attention-key-bias",
807+
"self_attention-value-bias",
808+
]
783809
moe_kernel_hooks = [
784810
"moe_block-gate-kernel",
785811
"moe_block-wi_0-kernel",
@@ -793,13 +819,17 @@ def reshape_kernel(input_tensor, target_shape):
793819
if scan_layers:
794820
for key in kernel_hooks:
795821
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
822+
for key in bias_hooks:
823+
mapping[f"params-decoder-layers-{key}"] = reshape_bias
796824
if num_experts > 1:
797825
for key in moe_kernel_hooks:
798826
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
799827
else:
800828
for i in range(n_layers):
801829
for key in kernel_hooks:
802830
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
831+
for key in bias_hooks:
832+
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
803833
if num_experts > 1:
804834
for key in moe_kernel_hooks:
805835
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
@@ -1376,7 +1406,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
13761406
# Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
13771407
num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
13781408
n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
1379-
text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
1409+
text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
13801410
config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
13811411
maxtext_config=maxtext_config,
13821412
scan_layers=scan_layers,
@@ -1544,7 +1574,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
15441574
# Text hooks, reusing QWEN3-MOE hook function
15451575
num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
15461576
n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
1547-
text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
1577+
text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
15481578
config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
15491579
maxtext_config=maxtext_config,
15501580
scan_layers=scan_layers,
@@ -2332,24 +2362,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
23322362
"gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
23332363
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
23342364
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
2335-
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2336-
"qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2337-
"qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2338-
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2339-
"qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2340-
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2341-
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2342-
"qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2343-
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2344-
"qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2345-
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2365+
"qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2366+
"qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2367+
"qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2368+
"qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2369+
"qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2370+
"qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2371+
"qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2372+
"qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2373+
"qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2374+
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2375+
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
23462376
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
23472377
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
23482378
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
2349-
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2350-
"qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2351-
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2352-
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
2379+
"qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2380+
"qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
2381+
"qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
23532382
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
23542383
"gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
23552384
"gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -2370,24 +2399,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
23702399
"gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23712400
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23722401
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2373-
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2374-
"qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2375-
"qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2376-
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2377-
"qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2378-
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2379-
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2380-
"qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2381-
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2382-
"qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2383-
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2402+
"qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2403+
"qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2404+
"qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2405+
"qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2406+
"qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2407+
"qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2408+
"qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2409+
"qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2410+
"qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2411+
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2412+
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23842413
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23852414
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23862415
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2387-
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2388-
"qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2389-
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2390-
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2416+
"qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2417+
"qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
2418+
"qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23912419
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
23922420
"gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
23932421
"gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum):
9292
GEMMA = "gemma"
9393
GEMMA2 = "gemma2"
9494
GEMMA3 = "gemma3"
95+
QWEN2 = "qwen2"
9596
QWEN3 = "qwen3"
9697
QWEN3_MOE = "qwen3_moe"
9798
QWEN3_NEXT = "qwen3_next"

0 commit comments

Comments
 (0)