@@ -1424,6 +1424,158 @@ def transform_query_kernel(arr):
   return hook_fns
 
 
+def MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns the mapping of parameter names from MaxText to Hugging Face for Mixtral.
+  """
+  mapping = {}
+
+  # Top-level, non-layer-specific parameters
+  mapping["params-token_embedder-embedding"] = "model.embed_tokens.weight"
+  mapping["params-decoder-decoder_norm-scale"] = "model.norm.weight"
+  mapping["params-decoder-logits_dense-kernel"] = "lm_head.weight"
+
+  num_experts = maxtext_config.num_experts
+
+  if scan_layers:
+    # Initialize lists for scanned layer weights
+    mapping.update(
+        {
+            "params-decoder-layers-self_attention-query-kernel": [],
+            "params-decoder-layers-self_attention-key-kernel": [],
+            "params-decoder-layers-self_attention-value-kernel": [],
+            "params-decoder-layers-self_attention-out-kernel": [],
+            "params-decoder-layers-pre_self_attention_layer_norm-scale": [],
+            "params-decoder-layers-post_self_attention_layer_norm-scale": [],
+            "params-decoder-layers-MoeBlock_0-gate-kernel": [],
+            "params-decoder-layers-MoeBlock_0-wi_0": [],
+            "params-decoder-layers-MoeBlock_0-wi_1": [],
+            "params-decoder-layers-MoeBlock_0-wo": [],
+        }
+    )
+
+    for i in range(config["num_hidden_layers"]):
+      hf_prefix = f"model.layers.{i}"
+      # Attention weights
+      mapping["params-decoder-layers-self_attention-query-kernel"].append(f"{hf_prefix}.self_attn.q_proj.weight")
+      mapping["params-decoder-layers-self_attention-key-kernel"].append(f"{hf_prefix}.self_attn.k_proj.weight")
+      mapping["params-decoder-layers-self_attention-value-kernel"].append(f"{hf_prefix}.self_attn.v_proj.weight")
+      mapping["params-decoder-layers-self_attention-out-kernel"].append(f"{hf_prefix}.self_attn.o_proj.weight")
+
+      # RMSNorm weights
+      mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"].append(f"{hf_prefix}.input_layernorm.weight")
+      mapping["params-decoder-layers-post_self_attention_layer_norm-scale"].append(
+          f"{hf_prefix}.post_attention_layernorm.weight"
+      )
+
+      # MoE gate
+      mapping["params-decoder-layers-MoeBlock_0-gate-kernel"].append(f"{hf_prefix}.block_sparse_moe.gate.weight")
+
+    # Iterate over experts in the outer loop and layers in the inner loop so the
+    # nesting order matches the stacking logic in _build_multi_axis_stacked_tensor().
+    for j in range(num_experts):
+      w1_layers = []
+      w3_layers = []
+      w2_layers = []
+
+      for i in range(config["num_hidden_layers"]):
+        hf_prefix = f"model.layers.{i}"
+        w1_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w1.weight")
+        w3_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w3.weight")
+        w2_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w2.weight")
+
+      mapping["params-decoder-layers-MoeBlock_0-wi_0"].append(w1_layers)
+      mapping["params-decoder-layers-MoeBlock_0-wi_1"].append(w3_layers)
+      mapping["params-decoder-layers-MoeBlock_0-wo"].append(w2_layers)
+
+  else:
+    for i in range(config["num_hidden_layers"]):
+      maxtext_prefix = f"params-decoder-layers_{i}"
+      hf_prefix = f"model.layers.{i}"
+
+      # Attention weights
+      mapping[f"{maxtext_prefix}-self_attention-query-kernel"] = f"{hf_prefix}.self_attn.q_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-key-kernel"] = f"{hf_prefix}.self_attn.k_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-value-kernel"] = f"{hf_prefix}.self_attn.v_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-out-kernel"] = f"{hf_prefix}.self_attn.o_proj.weight"
+
+      # RMSNorm weights
+      mapping[f"{maxtext_prefix}-pre_self_attention_layer_norm-scale"] = f"{hf_prefix}.input_layernorm.weight"
+      mapping[f"{maxtext_prefix}-post_self_attention_layer_norm-scale"] = f"{hf_prefix}.post_attention_layernorm.weight"
+
+      # MoE gate
+      mapping[f"{maxtext_prefix}-MoeBlock_0-gate-kernel"] = f"{hf_prefix}.block_sparse_moe.gate.weight"
+
+      # MoE expert weights (one MaxText param -> num_experts HF params)
+      w1_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w1.weight" for j in range(num_experts)]
+      w3_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w3.weight" for j in range(num_experts)]
+      w2_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w2.weight" for j in range(num_experts)]
+
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wi_0"] = w1_experts
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wi_1"] = w3_experts
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wo"] = w2_experts
+
+  return mapping
+
+
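For reference, a minimal usage sketch of the mapping above, with toy sizes chosen purely for illustration (a SimpleNamespace stands in for maxtext_config; real Mixtral-8x7B uses 32 layers and 8 experts):

from types import SimpleNamespace

toy_hf_config = {"num_hidden_layers": 1}
toy_maxtext_config = SimpleNamespace(num_experts=2)
m = MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(toy_hf_config, toy_maxtext_config, scan_layers=False)
# One MaxText expert param fans out to one HF param per expert:
assert m["params-decoder-layers_0-MoeBlock_0-wi_0"] == [
    "model.layers.0.block_sparse_moe.experts.0.w1.weight",
    "model.layers.0.block_sparse_moe.experts.1.w1.weight",
]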
+def MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Generates parameter conversion hooks for Mixtral between MaxText and Hugging Face.
+  """
+  hooks = {}
+
+  def reshape_and_transpose_attention(x, target_shape):
+    """MaxText: [hidden, n_heads, h_dim] <-> HF: [n_heads * h_dim, hidden]"""
+    if saving_to_hf:
+      # (H, N, D) -> (H, N*D) -> (N*D, H)
+      return x.reshape(config["hidden_size"], -1).transpose()
+    else:
+      # (N*D, H) -> (H, N*D) -> (H, N, D)
+      return x.transpose().reshape(target_shape)
+
+  def reshape_kernel(x, target_shape):
+    # MaxText (Flax) stores dense kernels as [in, out]; HF stores [out, in].
+    return x.transpose()
+
+  def scale_query_layer(input_tensor, target_shape):
+    # MaxText keeps the query kernel pre-scaled by 1/sqrt(head_dim): multiply the
+    # factor back in when saving to HF, divide it out when loading from HF.
+    if saving_to_hf:
+      depth_scale = np.dtype("float32").type(np.sqrt(maxtext_config.head_dim))
+      return (input_tensor * depth_scale).astype(input_tensor.dtype)
+    else:
+      depth_scale = np.dtype("float32").type(1 / np.sqrt(maxtext_config.head_dim))
+      return (input_tensor * depth_scale).astype(input_tensor.dtype)
+
+  if scan_layers:
+    plan = [
+        ("params-decoder-layers-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]),
+        ("params-decoder-layers-self_attention-key-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-self_attention-value-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-self_attention-out-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-MoeBlock_0-wi_0", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-wi_1", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-wo", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-gate-kernel", reshape_kernel),
+    ]
+  else:
+    plan = [
+        ("params-decoder-layers_{i}-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]),
+        ("params-decoder-layers_{i}-self_attention-key-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-self_attention-value-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-self_attention-out-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-MoeBlock_0-wi_0", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-wi_1", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-wo", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-gate-kernel", reshape_kernel),
+    ]
+  # The lm_head kernel needs the same transpose in both modes.
+  plan.append(("params-decoder-logits_dense-kernel", reshape_kernel))
+
+  for maxtext_pattern, op_func in plan:
+    if "{i}" in maxtext_pattern:
+      for i in range(config["num_hidden_layers"]):
+        hooks[maxtext_pattern.format(i=i)] = op_func
+    else:
+      hooks[maxtext_pattern] = op_func
+  return hooks
+
+
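As a sanity check on the attention hook, a minimal numpy sketch (toy dimensions are illustrative assumptions, not real Mixtral sizes) of the round trip it implements: saving flattens and transposes [hidden, n_heads, h_dim] into HF's [n_heads * h_dim, hidden], and loading inverts it exactly:

import numpy as np

hidden, n_heads, h_dim = 4, 2, 3  # toy sizes for illustration
mt = np.arange(hidden * n_heads * h_dim, dtype=np.float32).reshape(hidden, n_heads, h_dim)
hf = mt.reshape(hidden, -1).transpose()                # save: (H, N, D) -> (N*D, H)
back = hf.transpose().reshape(hidden, n_heads, h_dim)  # load: inverse
assert np.array_equal(mt, back)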
 # {maxtext model name: {maxtext weight name: hf weight name}}
 PARAM_MAPPING = {
     "gemma2-2b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -1448,6 +1600,8 @@ def transform_query_kernel(arr):
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
 }
 
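Converters are then fetched from this registry by model name; a hedged lookup sketch (the config values are the published Mixtral-8x7B sizes, and SimpleNamespace is an illustrative stand-in for maxtext_config):

from types import SimpleNamespace

mapping = PARAM_MAPPING["mixtral-8x7b"]({"num_hidden_layers": 32}, SimpleNamespace(num_experts=8), scan_layers=True)
print(mapping["params-decoder-logits_dense-kernel"])  # lm_head.weight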
 # {maxtext model name: {maxtext weight name: bi-directional transform}}
@@ -1474,6 +1628,8 @@ def transform_query_kernel(arr):
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
 
 VLLM_HOOK_FNS = {