
Commit f8aeead

Merge pull request #2870 from AI-Hypercomputer:mixtral_clean
PiperOrigin-RevId: 848247556
2 parents e61f6fa + cca9ac1 commit f8aeead

7 files changed

Lines changed: 305 additions & 4 deletions


src/MaxText/experimental/agent/ckpt_conversion_agent/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 # Checkpoint conversion agent
-The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/MaxText/utils/ckpt_conversion).
+The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion).

 ## Quick starts
 To begin, you'll need:
@@ -16,7 +16,7 @@ pip install -q -U "google-genai>=1.0.0"

 ## 1. Prepare the context file

-The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](ckpt_conversion/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/MaxText/experimental/agent/ckpt_conversion_agent/context/<model_name>` folder.
+The agent requires context files about the target and source models' parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory is defined by `config.base_output_directory`; the default is the `src/MaxText/experimental/agent/ckpt_conversion_agent/context/<model_name>` folder.
 ```bash
 python3 -m MaxText.experimental.agent.ckpt_conversion_agent.utils.save_param src/MaxText/configs/base.yml \
   per_device_batch_size=1 run_name=param_<model_name> model_name=<model_name> scan_layers=false \

src/MaxText/utils/ckpt_conversion/README.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The following models are supported:
 - Gemma2 (2B, 9B, 27B).
 - Gemma3 multimodal (4B, 12B, 27B).
 - Qwen3 (0.6B, 4B, 8B, 14B, 32B).
+- Mixtral (8x7B, 8x22B).

 ## Prerequisites
 - Hugging Face requires Pytorch.

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py

Lines changed: 65 additions & 0 deletions
@@ -691,6 +691,69 @@
 },
 )

+
+# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
+mixtral_8x7b_dict = {
+  "architectures": ["MixtralForCausalLM"],
+  "attention_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 14336,
+  "max_position_embeddings": 32768,
+  "model_type": "mixtral",
+  "num_attention_heads": 32,
+  "num_experts_per_tok": 2,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 8,
+  "num_local_experts": 8,
+  "output_router_logits": False,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 1000000.0,
+  "router_aux_loss_coef": 0.02,
+  "sliding_window": None,
+  "tie_word_embeddings": False,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.36.0.dev0",
+  "use_cache": True,
+  "vocab_size": 32000,
+}
+mixtral_8x7b_config = transformers.MixtralConfig(**mixtral_8x7b_dict)
+
+
+# from https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json
+mixtral_8x22b_dict = {
+  "architectures": ["MixtralForCausalLM"],
+  "attention_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 6144,
+  "initializer_range": 0.02,
+  "intermediate_size": 16384,
+  "max_position_embeddings": 65536,
+  "model_type": "mixtral",
+  "num_attention_heads": 48,
+  "num_experts_per_tok": 2,
+  "num_hidden_layers": 56,
+  "num_key_value_heads": 8,
+  "num_local_experts": 8,
+  "output_router_logits": False,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 1000000.0,
+  "router_aux_loss_coef": 0.001,
+  "sliding_window": None,
+  "tie_word_embeddings": False,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.38.0",
+  "use_cache": True,
+  "vocab_size": 32768,
+}
+mixtral_8x22b_config = transformers.MixtralConfig(**mixtral_8x22b_dict)
+
+
 # {maxtext model name: hf model config}
 HF_MODEL_CONFIGS = {
   "gemma2-2b": gemma2_2b_config,
@@ -716,4 +779,6 @@
   "gpt-oss-20b": gpt_oss_20b_config,
   "gpt-oss-120b": gpt_oss_120b_config,
   "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+  "mixtral-8x7b": mixtral_8x7b_config,
+  "mixtral-8x22b": mixtral_8x22b_config,
 }
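
The two dicts above mirror the upstream `config.json` files verbatim. For a quick look at what a registered config produces, here is a minimal sketch (not part of the commit) that instantiates a randomly initialized model from a deliberately tiny `MixtralConfig` (the sizes below are placeholders, not the real 8x7B values) and prints a few parameter names and shapes:

```python
import transformers

# Hypothetical smoke test: tiny sizes so the model fits in memory; the
# parameter *names* still follow the same pattern as the full 8x7B model.
tiny_cfg = transformers.MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=8,
    num_experts_per_tok=2,
    vocab_size=256,
)
model = transformers.AutoModelForCausalLM.from_config(tiny_cfg)
for name, tensor in list(model.state_dict().items())[:5]:
  print(name, tuple(tensor.shape))
```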

src/MaxText/utils/ckpt_conversion/utils/hf_shape.py

Lines changed: 73 additions & 0 deletions
@@ -581,6 +581,77 @@ def LLAMA31_HF_WEIGHTS_TO_SHAPE(config):
   return mapping


+def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
+  """
+  Returns a mapping of Hugging Face parameter names to their tensor shapes.
+
+  Args:
+    config (dict): The model configuration dictionary.
+
+  Returns:
+    A dictionary mapping Hugging Face parameter paths to their tensor shapes.
+  """
+  shapes = {}
+
+  # Embedding and LM Head
+  shapes["model.embed_tokens.weight"] = [config["vocab_size"], config["hidden_size"]]
+  shapes["lm_head.weight"] = [config["vocab_size"], config["hidden_size"]]
+
+  # Final LayerNorm
+  shapes["model.norm.weight"] = [config["hidden_size"]]
+
+  # Calculated dimensions
+  head_dim = config["hidden_size"] // config["num_attention_heads"]
+  kv_dim = config["num_key_value_heads"] * head_dim
+
+  # Decoder Layers
+  for i in range(config["num_hidden_layers"]):
+    # Attention Projections
+    shapes[f"model.layers.{i}.self_attn.q_proj.weight"] = [
+        config["hidden_size"],
+        config["hidden_size"],
+    ]
+    shapes[f"model.layers.{i}.self_attn.k_proj.weight"] = [
+        kv_dim,
+        config["hidden_size"],
+    ]
+    shapes[f"model.layers.{i}.self_attn.v_proj.weight"] = [
+        kv_dim,
+        config["hidden_size"],
+    ]
+    shapes[f"model.layers.{i}.self_attn.o_proj.weight"] = [
+        config["hidden_size"],
+        config["hidden_size"],
+    ]
+
+    # LayerNorms
+    shapes[f"model.layers.{i}.input_layernorm.weight"] = [config["hidden_size"]]
+    shapes[f"model.layers.{i}.post_attention_layernorm.weight"] = [config["hidden_size"]]
+
+    # MOE Gate
+    shapes[f"model.layers.{i}.block_sparse_moe.gate.weight"] = [
+        config["num_local_experts"],
+        config["hidden_size"],
+    ]
+
+    # MOE Experts
+    for j in range(config["num_local_experts"]):
+      shapes[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"] = [
+          config["intermediate_size"],
+          config["hidden_size"],
+      ]
+      shapes[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"] = [
+          config["hidden_size"],
+          config["intermediate_size"],
+      ]
+      shapes[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"] = [
+          config["intermediate_size"],
+          config["hidden_size"],
+      ]
+
+  return shapes
+
+
 # {maxtext model name: {hf weight name: hf shape}}
 HF_SHAPE = {
   "gemma2-2b": GEMMA2_HF_WEIGHTS_TO_SHAPE,
@@ -604,4 +675,6 @@ def LLAMA31_HF_WEIGHTS_TO_SHAPE(config):
   "deepseek3-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE,
   "gpt-oss-20b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,
   "gpt-oss-120b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,
+  "mixtral-8x7b": MIXTRAL_HF_WEIGHTS_TO_SHAPE,
+  "mixtral-8x22b": MIXTRAL_HF_WEIGHTS_TO_SHAPE,
 }
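
Because `MIXTRAL_HF_WEIGHTS_TO_SHAPE` is pure dictionary arithmetic, it can be sanity-checked directly against the 8x7B config added above. A minimal sketch, assuming both names are importable from the two modules in this commit:

```python
# Expected shapes for Mixtral-8x7B: hidden=4096, heads=32, kv_heads=8,
# so head_dim = 4096 // 32 = 128 and kv_dim = 8 * 128 = 1024.
shapes = MIXTRAL_HF_WEIGHTS_TO_SHAPE(mixtral_8x7b_dict)

assert shapes["model.embed_tokens.weight"] == [32000, 4096]
assert shapes["model.layers.0.self_attn.k_proj.weight"] == [1024, 4096]  # GQA: kv_dim rows
assert shapes["model.layers.0.block_sparse_moe.gate.weight"] == [8, 4096]
assert shapes["model.layers.0.block_sparse_moe.experts.7.w2.weight"] == [4096, 14336]
# Per layer: 4 attention + 2 norms + 1 gate + 8 experts * 3 = 31 entries,
# plus embed_tokens, lm_head, and the final norm.
assert len(shapes) == 32 * 31 + 3
```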

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 156 additions & 0 deletions
@@ -1424,6 +1424,158 @@ def transform_query_kernel(arr):
   return hook_fns


+def MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns the mapping of parameter names from MaxText to Hugging Face for Mixtral.
+  """
+  mapping = {}
+
+  # Top-level, non-layer-specific parameters
+  mapping["params-token_embedder-embedding"] = "model.embed_tokens.weight"
+  mapping["params-decoder-decoder_norm-scale"] = "model.norm.weight"
+  mapping["params-decoder-logits_dense-kernel"] = "lm_head.weight"
+
+  num_experts = maxtext_config.num_experts
+
+  if scan_layers:
+    # Initialize lists for scanned layer weights
+    mapping.update(
+        {
+            "params-decoder-layers-self_attention-query-kernel": [],
+            "params-decoder-layers-self_attention-key-kernel": [],
+            "params-decoder-layers-self_attention-value-kernel": [],
+            "params-decoder-layers-self_attention-out-kernel": [],
+            "params-decoder-layers-pre_self_attention_layer_norm-scale": [],
+            "params-decoder-layers-post_self_attention_layer_norm-scale": [],
+            "params-decoder-layers-MoeBlock_0-gate-kernel": [],
+            "params-decoder-layers-MoeBlock_0-wi_0": [],
+            "params-decoder-layers-MoeBlock_0-wi_1": [],
+            "params-decoder-layers-MoeBlock_0-wo": [],
+        }
+    )
+
+    for i in range(config["num_hidden_layers"]):
+      hf_prefix = f"model.layers.{i}"
+      # Attention weights
+      mapping["params-decoder-layers-self_attention-query-kernel"].append(f"{hf_prefix}.self_attn.q_proj.weight")
+      mapping["params-decoder-layers-self_attention-key-kernel"].append(f"{hf_prefix}.self_attn.k_proj.weight")
+      mapping["params-decoder-layers-self_attention-value-kernel"].append(f"{hf_prefix}.self_attn.v_proj.weight")
+      mapping["params-decoder-layers-self_attention-out-kernel"].append(f"{hf_prefix}.self_attn.o_proj.weight")
+
+      # RMSNorm weights
+      mapping["params-decoder-layers-pre_self_attention_layer_norm-scale"].append(f"{hf_prefix}.input_layernorm.weight")
+      mapping["params-decoder-layers-post_self_attention_layer_norm-scale"].append(
+          f"{hf_prefix}.post_attention_layernorm.weight"
+      )
+
+      # MoE gate
+      mapping["params-decoder-layers-MoeBlock_0-gate-kernel"].append(f"{hf_prefix}.block_sparse_moe.gate.weight")
+
+    # Outer loop over experts, inner loop over layers, to align with the logic in _build_multi_axis_stacked_tensor()
+    for j in range(num_experts):
+      w1_layers = []
+      w3_layers = []
+      w2_layers = []
+
+      for i in range(config["num_hidden_layers"]):
+        hf_prefix = f"model.layers.{i}"
+        w1_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w1.weight")
+        w3_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w3.weight")
+        w2_layers.append(f"{hf_prefix}.block_sparse_moe.experts.{j}.w2.weight")
+
+      mapping["params-decoder-layers-MoeBlock_0-wi_0"].append(w1_layers)
+      mapping["params-decoder-layers-MoeBlock_0-wi_1"].append(w3_layers)
+      mapping["params-decoder-layers-MoeBlock_0-wo"].append(w2_layers)
+
+  else:
+    for i in range(config["num_hidden_layers"]):
+      maxtext_prefix = f"params-decoder-layers_{i}"
+      hf_prefix = f"model.layers.{i}"
+
+      # Attention weights
+      mapping[f"{maxtext_prefix}-self_attention-query-kernel"] = f"{hf_prefix}.self_attn.q_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-key-kernel"] = f"{hf_prefix}.self_attn.k_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-value-kernel"] = f"{hf_prefix}.self_attn.v_proj.weight"
+      mapping[f"{maxtext_prefix}-self_attention-out-kernel"] = f"{hf_prefix}.self_attn.o_proj.weight"
+
+      # RMSNorm weights
+      mapping[f"{maxtext_prefix}-pre_self_attention_layer_norm-scale"] = f"{hf_prefix}.input_layernorm.weight"
+      mapping[f"{maxtext_prefix}-post_self_attention_layer_norm-scale"] = f"{hf_prefix}.post_attention_layernorm.weight"
+
+      # MoE gate
+      mapping[f"{maxtext_prefix}-MoeBlock_0-gate-kernel"] = f"{hf_prefix}.block_sparse_moe.gate.weight"
+
+      # MoE expert weights (1 MaxText param -> 8 HF params)
+      w1_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w1.weight" for j in range(num_experts)]
+      w3_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w3.weight" for j in range(num_experts)]
+      w2_experts = [f"{hf_prefix}.block_sparse_moe.experts.{j}.w2.weight" for j in range(num_experts)]
+
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wi_0"] = w1_experts
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wi_1"] = w3_experts
+      mapping[f"{maxtext_prefix}-MoeBlock_0-wo"] = w2_experts
+
+  return mapping
+
+
+def MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Generates parameter conversion hooks for Mixtral between MaxText and Hugging Face.
+  """
+  hooks = {}
+
+  def reshape_and_transpose_attention(x, target_shape):
+    """MaxText: [hidden, n_heads, h_dim] <-> HF: [n_heads * h_dim, hidden]"""
+    if saving_to_hf:
+      # (H, N, D) -> (H, N*D) -> (N*D, H)
+      return x.reshape(config["hidden_size"], -1).transpose()
+    else:
+      # (N*D, H) -> (H, N*D) -> (H, N, D)
+      return x.transpose().reshape(target_shape)
+
+  def reshape_kernel(x, target_shape):
+    return x.transpose()
+
+  def scale_query_layer(input_tensor, target_shape):
+    if saving_to_hf:
+      depth_scale = np.dtype("float32").type(np.sqrt(maxtext_config.head_dim))
+      return (input_tensor * depth_scale).astype(input_tensor.dtype)
+    else:
+      depth_scale = np.dtype("float32").type(1 / np.sqrt(maxtext_config.head_dim))
+      return (input_tensor * depth_scale).astype(input_tensor.dtype)
+
+  if scan_layers:
+    plan = [
+        ("params-decoder-layers-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]),
+        ("params-decoder-layers-self_attention-key-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-self_attention-value-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-self_attention-out-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers-MoeBlock_0-wi_0", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-wi_1", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-wo", reshape_kernel),
+        ("params-decoder-layers-MoeBlock_0-gate-kernel", reshape_kernel),
+    ]
+  else:
+    plan = [
+        ("params-decoder-layers_{i}-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]),
+        ("params-decoder-layers_{i}-self_attention-key-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-self_attention-value-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-self_attention-out-kernel", reshape_and_transpose_attention),
+        ("params-decoder-layers_{i}-MoeBlock_0-wi_0", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-wi_1", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-wo", reshape_kernel),
+        ("params-decoder-layers_{i}-MoeBlock_0-gate-kernel", reshape_kernel),
+    ]
+  plan.append(("params-decoder-logits_dense-kernel", reshape_kernel))
+
+  for maxtext_pattern, op_func in plan:
+    if "{i}" in maxtext_pattern:
+      for i in range(config["num_hidden_layers"]):
+        hooks[maxtext_pattern.format(i=i)] = op_func
+    else:
+      hooks[maxtext_pattern] = op_func
+  return hooks
+
+
 # {maxtext model name: {maxtext weight name: hf weight name}}
 PARAM_MAPPING = {
   "gemma2-2b": GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -1448,6 +1600,8 @@ def transform_query_kernel(arr):
   "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
   "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
   "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+  "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
+  "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
 }

 # {maxtext model name: {maxtext weight name: bi-directional transform}}
@@ -1474,6 +1628,8 @@ def transform_query_kernel(arr):
   "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
   "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
   "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+  "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+  "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }

 VLLM_HOOK_FNS = {
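
A subtlety in the hooks above: `reshape_and_transpose_attention` is written so the HF-to-MaxText branch exactly inverts the MaxText-to-HF branch, and `scale_query_layer` folds the attention scale into the query weights (multiply by sqrt(head_dim) when saving to HF, by its inverse when loading). A small NumPy round-trip check of the reshape, using toy dimensions rather than real Mixtral sizes:

```python
import numpy as np

# Toy dims: hidden=6, n_heads=2, head_dim=4, so N*D = 8 != hidden.
hidden, n_heads, head_dim = 6, 2, 4
maxtext_q = np.arange(hidden * n_heads * head_dim, dtype=np.float32).reshape(hidden, n_heads, head_dim)

# MaxText -> HF: (H, N, D) -> (H, N*D) -> (N*D, H)
hf_q = maxtext_q.reshape(hidden, -1).transpose()
assert hf_q.shape == (n_heads * head_dim, hidden)

# HF -> MaxText: (N*D, H) -> (H, N*D) -> (H, N, D)
roundtrip = hf_q.transpose().reshape(hidden, n_heads, head_dim)
assert np.array_equal(roundtrip, maxtext_q)
```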

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 6 additions & 1 deletion
@@ -76,6 +76,8 @@
   "gpt-oss-20b": "openai/gpt-oss-20b",
   "gpt-oss-120b": "openai/gpt-oss-120b",
   "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+  "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+  "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
 }


@@ -195,7 +197,10 @@ def process_maxtext_param(

   # Case 3 or 4: The source tensor is stacked on a single axis.
   # We determine if it's an unscanned MoE (expert axis) or standard scanned (layer axis).
-  is_unscanned_moe = "moe_block" in maxtext_param_key and any(
+  # `MoeBlock_0-w` matches the expert weights but not the gate: gate values are stacked
+  # along the layer axis only, while expert weights are stacked along both the expert and layer axes.
+  moe_block_list = ["moe_block", "MoeBlock_0-w"]
+  is_unscanned_moe = any(block in maxtext_param_key for block in moe_block_list) and any(
       f"_{i}-" in maxtext_param_key for i in range(maxtext_config.base_num_decoder_layers)
   )
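
The widened match list is easiest to see on concrete keys produced by `MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING` for unscanned layers; a minimal sketch:

```python
# Keys as produced by the Mixtral mapping above for unscanned layers.
moe_block_list = ["moe_block", "MoeBlock_0-w"]
gate_key = "params-decoder-layers_3-MoeBlock_0-gate-kernel"
expert_key = "params-decoder-layers_3-MoeBlock_0-wi_0"

# The gate matches neither pattern, so it is treated as stacked on the
# layer axis only; the expert weights match "MoeBlock_0-w" and are also
# treated as stacked on the expert axis.
assert not any(block in gate_key for block in moe_block_list)
assert any(block in expert_key for block in moe_block_list)
```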

tests/forward_pass_logit_checker.py

Lines changed: 2 additions & 1 deletion
@@ -380,7 +380,8 @@ def main(config, test_args): # pylint: disable=W0621
     raise ValueError("run_hf_model requires hf_model_path")
   hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch.bfloat16)
   tokenizer = AutoTokenizer.from_pretrained(test_args.hf_model_path)
-  if "Llama-3.1" in test_args.hf_model_path:
+  pad_token_models = ["Llama-3.1", "Mixtral-8x"]
+  if any(model in test_args.hf_model_path for model in pad_token_models):
     tokenizer.pad_token = tokenizer.eos_token

   init_rng = jax.random.PRNGKey(config.init_weights_seed)
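
Mixtral checkpoints, like Llama 3.1, ship a tokenizer without a dedicated pad token, so batched encoding needs a fallback; the check above reuses the EOS token. A standalone sketch of the same behavior, assuming `transformers` is installed and the checkpoint is reachable:

```python
from transformers import AutoTokenizer

hf_model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

pad_token_models = ["Llama-3.1", "Mixtral-8x"]
if any(model in hf_model_path for model in pad_token_models):
  tokenizer.pad_token = tokenizer.eos_token  # enables padded batch encoding

batch = tokenizer(["Hello", "A longer prompt"], padding=True, return_tensors="pt")
```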
