@@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
   return mapping
 
 
-def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
-  """Returns mapping from MaxText to HuggingFace Qwen weight paths.
+def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
 
   This function generates a dictionary that maps parameter names from a MaxText
-  Qwen checkpoint to their corresponding names in the Hugging Face format.
+  Qwen3 checkpoint to their corresponding names in the Hugging Face format.
   It handles both dense and Mixture-of-Experts (MoE) model variants.
 
   Args:
@@ -631,15 +631,6 @@ def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
       "params-decoder-layers-self_attention-value-kernel": [
           f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
       ],
-      "params-decoder-layers-self_attention-query-bias": [
-          f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
-      ],
-      "params-decoder-layers-self_attention-key-bias": [
-          f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
-      ],
-      "params-decoder-layers-self_attention-value-bias": [
-          f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
-      ],
       "params-decoder-layers-self_attention-out-kernel": [
           f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
       ],
@@ -697,9 +688,6 @@ def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
           f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
           f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
           f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
-          f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
-          f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
-          f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
           f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
           f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
           f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
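
Side note on the two key styles in the hunks above: with scan_layers, one MaxText parameter maps to a list of per-layer HF names, while the unscanned form emits a separate key for every layer. A minimal numpy sketch of that fan-out follows; the tensor shapes and the position of the stacked layer axis are illustrative assumptions, not read from this diff.

import numpy as np

# Illustrative sizes only: 4 scanned layers, value kernels of shape (hidden, heads, head_dim).
n_layers, hidden, num_heads, head_dim = 4, 8, 2, 4
stacked_v_kernel = np.random.randn(hidden, n_layers, num_heads, head_dim)  # assumed layer axis

# The list-valued mapping entry fans the stacked tensor out to one HF tensor per layer.
hf_names = [f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)]
per_layer = {name: stacked_v_kernel.take(i, axis=1) for i, name in enumerate(hf_names)}
assert per_layer["model.layers.0.self_attn.v_proj.weight"].shape == (hidden, num_heads, head_dim)
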
@@ -733,8 +721,8 @@ def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   return mapping
 
 
-def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
-  """Creates parameter transformation functions for Qwen.
+def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen3.
 
   This function provides a dictionary of transformation functions (hooks) for
   converting Qwen3 model parameters between MaxText and Hugging Face formats.
@@ -778,15 +766,6 @@ def reshape_kernel(input_tensor, target_shape):
     else:
       return input_tensor.T.reshape(target_shape)
 
-  def reshape_bias(input_tensor, target_shape=None):
-    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
-    if saving_to_hf:
-      # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
-      return input_tensor.reshape(target_shape)
-    else:
-      # HF [hidden_dim] -> MaxText [heads, head_dim]
-      return input_tensor.reshape(target_shape)
-
   mapping = {
       "params-token_embedder-embedding": pad_embedding_layer,
       "params-decoder-logits_dense-kernel": reshape_kernel,
@@ -801,11 +780,6 @@ def reshape_bias(input_tensor, target_shape=None):
       "mlp-wi_1-kernel",
       "mlp-wo-kernel",
   ]
-  bias_hooks = [
-      "self_attention-query-bias",
-      "self_attention-key-bias",
-      "self_attention-value-bias",
-  ]
   moe_kernel_hooks = [
       "moe_block-gate-kernel",
       "moe_block-wi_0-kernel",
@@ -819,17 +793,13 @@ def reshape_bias(input_tensor, target_shape=None):
   if scan_layers:
     for key in kernel_hooks:
       mapping[f"params-decoder-layers-{key}"] = reshape_kernel
-    for key in bias_hooks:
-      mapping[f"params-decoder-layers-{key}"] = reshape_bias
     if num_experts > 1:
       for key in moe_kernel_hooks:
         mapping[f"params-decoder-layers-{key}"] = reshape_kernel
   else:
     for i in range(n_layers):
       for key in kernel_hooks:
         mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
-      for key in bias_hooks:
-        mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
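
For orientation, here is a minimal numpy sketch of the HF-to-MaxText branch of the reshape_kernel hook registered above (input_tensor.T.reshape(target_shape)); the hidden size and head counts are chosen purely for illustration and are not taken from this diff.

import numpy as np

hidden, num_heads, head_dim = 2048, 16, 128                # illustrative assumptions
hf_q_proj = np.random.randn(num_heads * head_dim, hidden)  # HF Linear weight: [out_features, in_features]
maxtext_shape = (hidden, num_heads, head_dim)              # MaxText query kernel layout

# HF -> MaxText: transpose, then split the merged head dimension into (heads, head_dim).
maxtext_kernel = hf_q_proj.T.reshape(maxtext_shape)

# Saving back to HF is the inverse: merge the head axes and transpose again.
assert np.array_equal(maxtext_kernel.reshape(hidden, num_heads * head_dim).T, hf_q_proj)
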
@@ -1406,7 +1376,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
   # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
+  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -1574,7 +1544,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
   # Text hooks, reusing QWEN3-MOE hook function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -2362,23 +2332,24 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
-    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -2399,23 +2370,24 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
-    "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,