Skip to content

Commit 941d46a

Browse files
Merge pull request #2995 from AI-Hypercomputer:deepseek-moe
PiperOrigin-RevId: 861495189
2 parents aee5753 + aceba0a commit 941d46a

3 files changed

Lines changed: 319 additions & 1 deletion

File tree

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
dispatcher to retrieve the correct weight mapping configuration for a given
1919
model name. This allows for easy extension to support new models.
2020
"""
21-
21+
from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
22+
from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
2223
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
2324
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
2425

@@ -31,6 +32,10 @@ def __getattr__(self, name):
3132
return LLAMA3_VLLM_MAPPING
3233
elif name.startswith("qwen3"):
3334
return QWEN3_VLLM_MAPPING
35+
elif name.startswith("deepseek3"):
36+
return DEEPSEEK_VLLM_MAPPING
37+
elif name.startswith("gpt-oss"):
38+
return GPT_OSS_VLLM_MAPPING
3439
else:
3540
raise ValueError(f"{name} vLLM weight mapping not found.")
3641

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys."""
16+
17+
from dataclasses import dataclass
18+
19+
20+
@dataclass
class DEEPSEEK_VLLM_MAPPING:
  """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys."""

  @staticmethod
  def to_hf_hook_fns():
    """Returns per-key hook functions applied to weights before export.

    MaxText MLA projection weights are stored 3D as (Rank, Heads, HeadDim);
    tpu-inference expects them 2D as (Rank, Heads * HeadDim) before it
    performs its own splitting, so those keys get a flattening hook.
    """

    def _collapse_head_dims(val):
      # (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim); anything that is
      # not 3D passes through untouched.
      return val.reshape(val.shape[0], -1) if val.ndim == 3 else val

    return {
        "base.decoder.layers.self_attention.wq_b.kernel": _collapse_head_dims,
        "base.decoder.layers.self_attention.wkv_b.kernel": _collapse_head_dims,
        "base.decoder.layers.self_attention.out.kernel": _collapse_head_dims,
    }

  @staticmethod
  def to_hf_transpose_keys():
    """Returns a list of keys for weights that need to be transposed.

    Returns:
      An empty dictionary, as no keys require transposition for this mapping.
    """
    return {}

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_mapping():
    """Returns the weight mapping for the model."""
    return {
        # ---- Base model params ----
        # Mapped to HF names to stay compatible with loader regexes.
        "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)),
        "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)),
        "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")),
        # ---- MLA layers (HF keys trigger the loader's splitting logic) ----
        # Layer norms.
        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
            "model.layers.*.input_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.post_self_attention_layer_norm.scale": (
            "model.layers.*.post_attention_layernorm.weight",
            (None, "layer"),
        ),
        # MLA-specific norms.
        "base.decoder.layers.self_attention.kv_norm.scale": (
            "model.layers.*.self_attn.kv_a_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.self_attention.q_norm.scale": (
            "model.layers.*.self_attn.q_a_layernorm.weight",
            (None, "layer"),
        ),
        # MLA projections. HF names are used so `DeepSeekV3WeightLoader`
        # detects "kv_b_proj" and performs the split into k_b / v_b that the
        # MLA kernel needs.
        "base.decoder.layers.self_attention.wq_a.kernel": (
            "model.layers.*.self_attn.q_a_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wq_b.kernel": (
            "model.layers.*.self_attn.q_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_a.kernel": (
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_b.kernel": (
            "model.layers.*.self_attn.kv_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.out.kernel": (
            "model.layers.*.self_attn.o_proj.weight",
            ("model", "layer", None, None),
        ),
        # ---- Dense MLP layers (vLLM keys for safety/consistency) ----
        "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)),
        # ---- MoE layers (internal keys to bypass loader stacking) ----
        # MaxText experts are already fused/stacked, so these map straight to
        # the internal `tpu-inference` param names. The loader will not find
        # "experts.{i}" in the name and falls back to loading them directly,
        # which is exactly what we want for performance.
        # Shared experts.
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (
            "layers.*.shared_experts.kernel_gating_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (
            "layers.*.shared_experts.kernel_up_proj_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": (
            "layers.*.shared_experts.kernel_down_proj_FD",
            ("model", "layer", None),
        ),
        # Router.
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (
            "layers.*.custom_module.router.kernel_DE",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (
            "layers.*.custom_module.router.bias_E",
            (None, "layer", "model"),
        ),
        # Routed experts (fused).
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": (
            "layers.*.custom_module.kernel_gating_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": (
            "layers.*.custom_module.kernel_up_proj_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": (
            "layers.*.custom_module.kernel_down_proj_EFD",
            ("expert", "layer", "model", None),
        ),
        # ---- MTP block (included for completeness; typically skipped by the
        # current loader) ----
        "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")),
    }
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys."""
16+
17+
from dataclasses import dataclass
18+
from typing import Dict, Optional, Tuple
19+
20+
21+
@dataclass
class GPT_OSS_VLLM_MAPPING:
  """
  Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
  Supports:
  - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
  - Contiguous ("block") interleaving (Block 0 -> Layers 0..k-1) via any
    other `interleave_style` value.
  """

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_hook_fns():
    """Returns hook functions to fuse interleaved weights (none needed)."""
    return {}

  @staticmethod
  def to_hf_transpose_keys():
    """Returns keys that need to be transposed (none for this mapping)."""
    return {}

  @staticmethod
  def to_hf_mapping(
      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
    """Returns the weight mapping for the model.

    Args:
      layer_cycle_interval: The interval at which layers are cycled (i.e. the
        number of scanned blocks).
      total_num_layers: The total number of layers in the model.
      interleave_style: The style of interleaving used for the layers.
        "modulo" maps scanned block i to layers i, i + interval, ...;
        any other value maps block i to a contiguous slice of layers.

    Returns:
      A dictionary mapping MaxText parameter names to
      (vLLM parameter-name regex, sharding spec) tuples.
    """

    mapping = {}

    # --- 1. Global Parameters ---
    mapping.update(
        {
            "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
            "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
            "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
        }
    )

    # --- 2. Layer Mapping Loop ---
    # NOTE(review): contiguous ("block") style assumes total_num_layers is
    # divisible by layer_cycle_interval; otherwise trailing layers are
    # unmapped — confirm against the model configs that call this.
    layers_per_block = total_num_layers // layer_cycle_interval

    for block_idx in range(layer_cycle_interval):
      src_block = f"base.decoder.layers.layers_{block_idx}"
      if interleave_style == "modulo":
        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
      else:
        start = block_idx * layers_per_block
        target_indices = range(start, start + layers_per_block)

      regex_indices = "|".join(map(str, target_indices))
      # Raw f-string: "\." is a literal backslash-dot for the regex. The
      # previous non-raw form relied on an invalid escape sequence, which is
      # a SyntaxWarning on Python 3.12+ (same resulting string value).
      layer_regex = rf"layers\.({regex_indices})"

      # --- 3. Block Mappings (Standard) ---
      mapping.update(
          {
              f"{src_block}.pre_self_attention_layer_norm.scale": (
                  f"{layer_regex}.pre_attention_norm.scale",
                  (None, "layer"),
              ),
              f"{src_block}.post_self_attention_layer_norm.scale": (f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")),
              f"{src_block}.GptOssAttention.query.kernel": (
                  f"{layer_regex}.attn.kernel_q_DNH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.key.kernel": (
                  f"{layer_regex}.attn.kernel_k_DKH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.value.kernel": (
                  f"{layer_regex}.attn.kernel_v_DKH",
                  (None, "layer", "model", None),
              ),
              f"{src_block}.GptOssAttention.out.kernel": (
                  f"{layer_regex}.attn.kernel_o_proj_NHD",
                  ("model", "layer", None, None),
              ),
              f"{src_block}.GptOssAttention.query.bias": (f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.key.bias": (f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.value.bias": (f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)),
              f"{src_block}.GptOssAttention.out.bias": (f"{layer_regex}.attn.bias_o_D", (None, "layer")),
              f"{src_block}.GptOssAttention.sinks": (f"{layer_regex}.attn.sinks_N", (None, "layer")),
          }
      )

      # MoE Router
      mapping.update(
          {
              f"{src_block}.GptOssMlp.gate.kernel": (
                  f"{layer_regex}.custom_module.router.kernel_DE",
                  (None, "layer", "model"),
              ),
              f"{src_block}.GptOssMlp.gate.bias": (f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")),
          }
      )

      # --- MOE EXPERTS ---
      # Separate gate_proj (wi_0) and up_proj (wi_1) kernels and biases.

      # MLP Gate Projection (wi_0)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wi_0": (f"{layer_regex}.custom_module.gate_proj_kernel", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wi_0_bias": (f"{layer_regex}.custom_module.gate_proj_bias", ("model", "layer")),
          }
      )

      # MLP Up Projection (wi_1)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wi_1": (f"{layer_regex}.custom_module.up_proj_kernel", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wi_1_bias": (f"{layer_regex}.custom_module.up_proj_bias", ("model", "layer")),
          }
      )

      # MLP Down Projection (wo)
      mapping.update(
          {
              f"{src_block}.GptOssMlp.wo": (f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)),
              f"{src_block}.GptOssMlp.wo_bias": (f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")),
          }
      )

    return mapping

0 commit comments

Comments
 (0)