Commit 2b3a5f1

Merge pull request #2973 from AI-Hypercomputer:rbierneni-qwen3next-chkpt-util
PiperOrigin-RevId: 874191942
Parents: 2ade63f + f75c7e2 · Commit: 2b3a5f1

5 files changed

Lines changed: 380 additions & 0 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ The following models are supported:
 | **Mixtral** | 8x7B, 8x22B |||||
 | **GPT-OSS** | 20B, 120B |||||
 | **DeepSeek3** | 671B | - | - || - |
+| **Qwen3 Next** | 80B |||||
 
 ## Prerequisites
 
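For orientation only (this note is not part of the commit): the new table row advertises Qwen3 Next 80B in the checkpoint-conversion guide. A minimal sketch of fetching the matching Hugging Face config with transformers, assuming the hub id Qwen/Qwen3-Next-80B-A3B-Instruct and a transformers build recent enough to know the qwen3_next model type (the commit pins 4.57.0.dev0):

import transformers

# Assumed hub id; any Qwen3-Next 80B checkpoint whose config has model_type "qwen3_next" works here.
cfg = transformers.AutoConfig.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
print(cfg.model_type, cfg.num_hidden_layers)  # expected: qwen3_next 48
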
src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 44 additions & 0 deletions
@@ -701,6 +701,49 @@
 },
 )
 
+qwen3_next_80b_a3b_dict = {
+    "architectures": ["Qwen3NextForCausalLM"],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936,
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
 
 # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
 mixtral_8x7b_dict = {
@@ -789,6 +832,7 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
     "mixtral-8x7b": mixtral_8x7b_config,
     "mixtral-8x22b": mixtral_8x22b_config,
 }
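Illustrative usage (not part of the diff): the hunk above registers the new config object under the key "qwen3-next-80b-a3b". The name of the enclosing registry dict is not visible in this hunk, so HF_MODEL_CONFIGS below is a placeholder; the sketch only shows how such a lookup would be consumed, assuming a transformers version that ships Qwen3NextConfig:

import transformers

# Placeholder registry name; the real dict lives in hf_model_configs.py under its own identifier.
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(num_hidden_layers=48, num_attention_heads=16)
HF_MODEL_CONFIGS = {"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config}

cfg = HF_MODEL_CONFIGS["qwen3-next-80b-a3b"]
print(cfg.model_type, cfg.num_hidden_layers)  # qwen3_next 48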

src/maxtext/checkpoint_conversion/utils/hf_shape.py

Lines changed: 96 additions & 0 deletions
@@ -349,6 +349,102 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Calculated Values ---
+  q_dim = num_attention_heads * head_dim
+  kv_dim = num_key_value_heads * head_dim
+
+  linear_k_dim = linear_num_key_heads * linear_key_head_dim
+  linear_v_dim = linear_num_value_heads * head_dim
+  conv_dim = 2 * linear_k_dim + linear_v_dim
+  qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim
+  ba_dim = 2 * linear_num_value_heads
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update(
+          {
+              f"{layer_prefix}.self_attn.q_proj.weight": [2 * q_dim, hidden_size],
+              f"{layer_prefix}.self_attn.k_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.v_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, q_dim],
+              f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+              f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+          }
+      )
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update(
+          {
+              f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [qkvz_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.in_proj_ba.weight": [ba_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.conv1d.weight": [conv_dim, 1, linear_conv_kernel_dim],
+              f"{layer_prefix}.linear_attn.A_log": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.dt_bias": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.norm.weight": [head_dim],
+              f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, linear_v_dim],
+          }
+      )
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update(
+        {
+            # Router
+            f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+            # Shared Experts (SwiGLU - Separate Weights)
+            f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+            # Shared Expert Gate (learned scaling factor)
+            f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+        }
+    )
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update(
+          {
+              f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+          }
+      )
+
+  return mapping
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path and their shape."""
   # --- Extract Core Config Values ---