Commit f75c7e2

Add Qwen3-Next to checkpoint util (Squashed)
- Lower the decoding length to 128 instead of 512
- Update test script
- Add manual calcs for each HF shape instead of hardcoded values
- Run pylint
- Fix config values and remove unused vars
- Fix /configs path in script
- Fix decode path in script
- Update model README
- Update scripts to use new train path
- Reset qwen3 test files to match main
- Undo the temp fix to get training working
- Update reshape function to match what other models use
- Remove whitespace
1 parent 40071fc commit f75c7e2

5 files changed: 380 additions & 0 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@ The following models are supported:
 | **Mixtral** | 8x7B, 8x22B |||||
 | **GPT-OSS** | 20B, 120B |||||
 | **DeepSeek3** | 671B | - | - || - |
+| **Qwen3 Next** | 80B |||||
 
 ## Prerequisites
```

src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -701,6 +701,49 @@
     },
 )
 
+qwen3_next_80b_a3b_dict = {
+    "architectures": ["Qwen3NextForCausalLM"],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936,
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
 
 # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
 mixtral_8x7b_dict = {
@@ -789,6 +832,7 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
     "mixtral-8x7b": mixtral_8x7b_config,
     "mixtral-8x22b": mixtral_8x22b_config,
 }
```
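
For reference, the new entry can be exercised on its own. The sketch below is not part of the commit; it assumes the repository is installed as a package so that `maxtext.checkpoint_conversion.utils` is importable, and that the installed `transformers` release (>= 4.57) provides `Qwen3NextConfig`.

```python
# Sketch: look up the new Qwen3-Next config and print the derived attention
# widths that the shape mapping in hf_shape.py computes from it.
from maxtext.checkpoint_conversion.utils import hf_model_configs

cfg = hf_model_configs.qwen3_next_80b_a3b_config
print(cfg.num_attention_heads * cfg.head_dim)  # q_dim:  16 * 256 = 4096
print(cfg.num_key_value_heads * cfg.head_dim)  # kv_dim:  2 * 256 = 512
print(cfg.full_attention_interval)             # every 4th layer uses full attention
```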

src/maxtext/checkpoint_conversion/utils/hf_shape.py

Lines changed: 96 additions & 0 deletions
```diff
@@ -349,6 +349,102 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Calculated Values ---
+  q_dim = num_attention_heads * head_dim
+  kv_dim = num_key_value_heads * head_dim
+
+  linear_k_dim = linear_num_key_heads * linear_key_head_dim
+  linear_v_dim = linear_num_value_heads * head_dim
+  conv_dim = 2 * linear_k_dim + linear_v_dim
+  qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim
+  ba_dim = 2 * linear_num_value_heads
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update(
+          {
+              f"{layer_prefix}.self_attn.q_proj.weight": [2 * q_dim, hidden_size],
+              f"{layer_prefix}.self_attn.k_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.v_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, q_dim],
+              f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+              f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+          }
+      )
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update(
+          {
+              f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [qkvz_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.in_proj_ba.weight": [ba_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.conv1d.weight": [conv_dim, 1, linear_conv_kernel_dim],
+              f"{layer_prefix}.linear_attn.A_log": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.dt_bias": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.norm.weight": [head_dim],
+              f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, linear_v_dim],
+          }
+      )
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update(
+        {
+            # Router
+            f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+            # Shared Experts (SwiGLU - Separate Weights)
+            f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+            # Shared Expert Gate (learned scaling factor)
+            f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+        }
+    )
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update(
+          {
+              f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+          }
+      )
+
+  return mapping
+
+
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path and their shape."""
   # --- Extract Core Config Values ---
```
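
To see what the new mapping produces end to end, here is a small usage sketch. It is not part of the commit and assumes the same package layout as above (`maxtext.checkpoint_conversion.utils` importable); the expected values follow directly from the config arithmetic in the function.

```python
# Sketch: build the HF weight-name -> shape map for Qwen3-Next-80B-A3B and
# spot-check a few entries against the config values above.
from maxtext.checkpoint_conversion.utils import hf_model_configs, hf_shape

shapes = hf_shape.QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(hf_model_configs.qwen3_next_80b_a3b_dict)

# Layers cycle in groups of full_attention_interval = 4: layers 0-2 use linear
# (GDN) attention, layer 3 uses full attention, and so on.
assert "model.layers.0.linear_attn.in_proj_qkvz.weight" in shapes
assert "model.layers.3.self_attn.q_proj.weight" in shapes

# q_proj carries query and gate: 2 * num_attention_heads * head_dim = 2 * 16 * 256 = 8192 rows.
assert shapes["model.layers.3.self_attn.q_proj.weight"] == [8192, 2048]

# GDN in_proj_qkvz: 2 * (16 * 128) + 2 * (32 * 256) = 4096 + 16384 = 20480 rows.
assert shapes["model.layers.0.linear_attn.in_proj_qkvz.weight"] == [20480, 2048]
```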
