@@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
   return mapping
 
 
-def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
-  """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
+def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen weight paths.
 
   This function generates a dictionary that maps parameter names from a MaxText
-  Qwen3 checkpoint to their corresponding names in the Hugging Face format.
+  Qwen checkpoint to their corresponding names in the Hugging Face format.
   It handles both dense and Mixture-of-Experts (MoE) model variants.
 
   Args:
@@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
631631 "params-decoder-layers-self_attention-value-kernel" : [
632632 f"model.layers.{ i } .self_attn.v_proj.weight" for i in range (n_layers )
633633 ],
634+ "params-decoder-layers-self_attention-query-bias" : [
635+ f"model.layers.{ i } .self_attn.q_proj.bias" for i in range (n_layers )
636+ ],
637+ "params-decoder-layers-self_attention-key-bias" : [
638+ f"model.layers.{ i } .self_attn.k_proj.bias" for i in range (n_layers )
639+ ],
640+ "params-decoder-layers-self_attention-value-bias" : [
641+ f"model.layers.{ i } .self_attn.v_proj.bias" for i in range (n_layers )
642+ ],
634643 "params-decoder-layers-self_attention-out-kernel" : [
635644 f"model.layers.{ i } .self_attn.o_proj.weight" for i in range (n_layers )
636645 ],
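
In the scanned-layers form above, a single MaxText key names one stacked parameter, and the mapped value is the list of per-layer HF tensor names. A minimal sketch of how the new bias entries could be consumed when exporting to HF; the array shapes, the stacked-axis position, and the export loop are illustrative assumptions, not code from this file:

```python
import numpy as np

# Assumed toy shapes: 4 scanned layers, 8 query heads, head_dim 16 (illustrative only).
n_layers, num_heads, head_dim = 4, 8, 16
stacked_q_bias = np.zeros((n_layers, num_heads, head_dim), dtype=np.float32)

# The mapping entry added above, in its scanned form: one MaxText key -> list of HF names.
hf_names = [f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)]

# Hypothetical export step: slice off each layer and flatten to the 1D [hidden] layout HF uses.
hf_state = {
    name: stacked_q_bias[i].reshape(num_heads * head_dim) for i, name in enumerate(hf_names)
}
assert hf_state["model.layers.0.self_attn.q_proj.bias"].shape == (num_heads * head_dim,)
```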
@@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
688697 f"params-decoder-layers_{ i } -self_attention-key-kernel" : f"model.layers.{ i } .self_attn.k_proj.weight" ,
689698 f"params-decoder-layers_{ i } -self_attention-value-kernel" : f"model.layers.{ i } .self_attn.v_proj.weight" ,
690699 f"params-decoder-layers_{ i } -self_attention-out-kernel" : f"model.layers.{ i } .self_attn.o_proj.weight" ,
700+ f"params-decoder-layers_{ i } -self_attention-query-bias" : f"model.layers.{ i } .self_attn.q_proj.bias" ,
701+ f"params-decoder-layers_{ i } -self_attention-key-bias" : f"model.layers.{ i } .self_attn.k_proj.bias" ,
702+ f"params-decoder-layers_{ i } -self_attention-value-bias" : f"model.layers.{ i } .self_attn.v_proj.bias" ,
691703 f"params-decoder-layers_{ i } -self_attention-query_norm-scale" : f"model.layers.{ i } .self_attn.q_norm.weight" ,
692704 f"params-decoder-layers_{ i } -self_attention-key_norm-scale" : f"model.layers.{ i } .self_attn.k_norm.weight" ,
693705 f"params-decoder-layers_{ i } -post_self_attention_layer_norm-scale" : f"model.layers.{ i } .post_attention_layernorm.weight" ,
@@ -721,8 +733,8 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
   return mapping
 
 
-def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
-  """Creates parameter transformation functions for Qwen3.
+def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen.
 
   This function provides a dictionary of transformation functions (hooks) for
   converting Qwen3 model parameters between MaxText and Hugging Face formats.
@@ -766,6 +778,15 @@ def reshape_kernel(input_tensor, target_shape):
     else:
       return input_tensor.T.reshape(target_shape)
 
+  def reshape_bias(input_tensor, target_shape=None):
+    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
+    if saving_to_hf:
+      # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
+      return input_tensor.reshape(target_shape)
+    else:
+      # HF [hidden_dim] -> MaxText [heads, head_dim]
+      return input_tensor.reshape(target_shape)
+
   mapping = {
       "params-token_embedder-embedding": pad_embedding_layer,
       "params-decoder-logits_dense-kernel": reshape_kernel,
@@ -780,6 +801,11 @@ def reshape_kernel(input_tensor, target_shape):
780801 "mlp-wi_1-kernel" ,
781802 "mlp-wo-kernel" ,
782803 ]
804+ bias_hooks = [
805+ "self_attention-query-bias" ,
806+ "self_attention-key-bias" ,
807+ "self_attention-value-bias" ,
808+ ]
783809 moe_kernel_hooks = [
784810 "moe_block-gate-kernel" ,
785811 "moe_block-wi_0-kernel" ,
@@ -793,13 +819,17 @@ def reshape_kernel(input_tensor, target_shape):
   if scan_layers:
     for key in kernel_hooks:
       mapping[f"params-decoder-layers-{key}"] = reshape_kernel
+    for key in bias_hooks:
+      mapping[f"params-decoder-layers-{key}"] = reshape_bias
     if num_experts > 1:
       for key in moe_kernel_hooks:
         mapping[f"params-decoder-layers-{key}"] = reshape_kernel
   else:
     for i in range(n_layers):
       for key in kernel_hooks:
         mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
+      for key in bias_hooks:
+        mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
@@ -1376,7 +1406,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
   # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
+  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -1544,7 +1574,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
   # Text hooks, reusing QWEN3-MOE hook function
   num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
   n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
-  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+  text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
       config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
       maxtext_config=maxtext_config,
       scan_layers=scan_layers,
@@ -2332,24 +2362,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
23322362 "gemma3-4b" : GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING ,
23332363 "gemma3-12b" : GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING ,
23342364 "gemma3-27b" : GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2335- "qwen3 -0.6b " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2336- "qwen3 -1.7b " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2337- "qwen3-1.7b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2338- "qwen3-4b " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2339- "qwen3-4b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2340- "qwen3-4b-thinking-2507 " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2341- "qwen3-8b " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2342- "qwen3-8b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2343- "qwen3-14b " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2344- "qwen3-14b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2345- "qwen3-32b" : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2365+ "qwen2.5 -0.5b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2366+ "qwen2.5 -1.5b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2367+ "qwen2.5-3b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2368+ "qwen2.5-7b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2369+ "qwen2.5-14b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2370+ "qwen3-0.6b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2371+ "qwen3-4b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2372+ "qwen3-4b-thinking-2507 " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2373+ "qwen3-8b " : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2374+ "qwen3-14b" : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2375+ "qwen3-32b" : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
23462376 "llama3.1-8b" : LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING ,
23472377 "llama3.1-70b" : LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING ,
23482378 "llama3.1-405b" : LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING ,
2349- "qwen3-30b-a3b" : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2350- "qwen3-30b-a3b-base" : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2351- "qwen3-235b-a22b" : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2352- "qwen3-coder-480b-a35b" : QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING ,
2379+ "qwen3-30b-a3b" : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2380+ "qwen3-235b-a22b" : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
2381+ "qwen3-coder-480b-a35b" : QWEN_MAXTEXT_TO_HF_PARAM_MAPPING ,
23532382 "deepseek3-671b" : DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING ,
23542383 "gpt-oss-20b" : GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING ,
23552384 "gpt-oss-120b" : GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING ,
@@ -2370,24 +2399,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
23702399 "gemma3-4b" : GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23712400 "gemma3-12b" : GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23722401 "gemma3-27b" : GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2373- "qwen3 -0.6b " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2374- "qwen3 -1.7b " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2375- "qwen3-1.7b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2376- "qwen3-4b " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2377- "qwen3-4b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2378- "qwen3-4b-thinking-2507 " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2379- "qwen3-8b " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2380- "qwen3-8b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2381- "qwen3-14b " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2382- "qwen3-14b-base " : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2383- "qwen3-32b" : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2402+ "qwen2.5 -0.5b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2403+ "qwen2.5 -1.5b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2404+ "qwen2.5-3b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2405+ "qwen2.5-7b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2406+ "qwen2.5-14b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2407+ "qwen3-0.6b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2408+ "qwen3-4b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2409+ "qwen3-4b-thinking-2507 " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2410+ "qwen3-8b " : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2411+ "qwen3-14b" : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2412+ "qwen3-32b" : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23842413 "llama3.1-8b" : LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23852414 "llama3.1-70b" : LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23862415 "llama3.1-405b" : LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2387- "qwen3-30b-a3b" : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2388- "qwen3-30b-a3b-base" : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2389- "qwen3-235b-a22b" : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2390- "qwen3-coder-480b-a35b" : QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2416+ "qwen3-30b-a3b" : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2417+ "qwen3-235b-a22b" : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
2418+ "qwen3-coder-480b-a35b" : QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23912419 "deepseek3-671b" : DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN ,
23922420 "gpt-oss-20b" : GPT_OSS_TO_HF_PARAM_HOOK_FN ,
23932421 "gpt-oss-120b" : GPT_OSS_TO_HF_PARAM_HOOK_FN ,