Commit 1d02688

Onboard DeepSeek MHC feature
1 parent 27eada9 commit 1d02688

7 files changed

Lines changed: 519 additions & 0 deletions

src/MaxText/common_types.py

Lines changed: 6 additions & 0 deletions
@@ -114,3 +114,9 @@ class AttentionType(enum.Enum):
class ShardMode(enum.Enum):
  AUTO = "auto"  # default
  EXPLICIT = "explicit"


class HyperConnectionType(enum.Enum):
  ATTENTION = "attention"
  MLP_MOE = "mlp_moe"
  MLP_DENSE = "mlp_dense"

src/MaxText/configs/base.yml

Lines changed: 6 additions & 0 deletions
@@ -1065,3 +1065,9 @@ vllm_hf_config_path: ""
vllm_additional_config: {}
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
force_q_layout: false

################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
# The number of parallel streams in Hyper Connection.
mhc_expansion_rate: 0
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
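
This hunk only declares the two knobs; how they gate the feature is not shown here. A plausible sketch is that a decoder block constructs the new layers/mhc.py module (added later in this commit) only when the expansion rate is positive; `config.emb_dim`, `mesh`, and `rngs` below are assumed to be whatever the block already has in scope.

# Illustrative gating only - not code from this commit.
if config.mhc_expansion_rate > 0:
  mhc = ManifoldConstrainedHyperConnections(config=config, dim=config.emb_dim, mesh=mesh, rngs=rngs)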
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Small model config for testing (derived from DeepSeek V3.2 - 671B)

base_emb_dim: 1024  # Reduced from 7168
base_num_query_heads: 16  # Reduced from 128
base_num_kv_heads: 16  # Reduced from 128
base_mlp_dim: 2048  # Reduced from 18432
base_moe_mlp_dim: 512  # Reduced from 2048
base_num_decoder_layers: 6  # Reduced from 61
first_num_dense_layers: 1  # Reduced from 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 16  # Reduced from 256
num_experts_per_tok: 2  # Reduced from 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek"
# MLA
attention_type: "mla"
q_lora_rank: 384  # Reduced from 1536
kv_lora_rank: 128  # Reduced from 512
qk_nope_head_dim: 32  # Reduced from 128
qk_rope_head_dim: 16  # Reduced from 64
v_head_dim: 128
# RoPE
mscale: 1.0
rope_type: "yarn"
rope_max_timescale: 10_000
max_position_embeddings: 4096  # Reduced for local testing
original_max_position_embeddings: 4096
rope_factor: 1
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
# Indexer for DeepSeek Sparse Attention
use_sparse_indexer: True
index_n_heads: 16  # Reduced from 64
index_head_dim: 64  # Reduced from 128
index_topk: 256  # Reduced from 2048
# Hyper-connections: mHC enabled
mhc_expansion_rate: 4
sinkhorn_iterations: 20
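
For scale, with mhc_expansion_rate: 4 and base_emb_dim: 1024, the parameter shapes implied by layers/mhc.py (further down in this commit) stay small; a back-of-the-envelope count, assuming the wrapper is built with dim equal to the 1024-wide embedding:

# Rough per-sub-layer mHC parameter count for this test config (dim=1024 is an assumption).
k, d = 4, 1024
alpha_params = k * d * (k * k + 2 * k)       # res_alpha (4096, 16) plus pre/post_alpha (4096, 4) each = 98304
beta_and_scale_params = (k * k + 2 * k) + 3  # res/pre/post betas plus the three scalar scales = 27
norm_params = k * d                          # RMSNorm scale over the flattened streams = 4096
total = alpha_params + beta_and_scale_params + norm_params  # 102427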

src/MaxText/configs/types.py

Lines changed: 9 additions & 0 deletions
@@ -210,6 +210,7 @@ class ProfilerType(str, Enum):
    "deepseek3-test",
    "deepseek3-tiny",
    "deepseek3.2-671b",
    "deepseek-custom",
    "kimi-k2-1t",
    "gemma-7b",
    "gemma-2b",
@@ -1057,6 +1058,13 @@ class TrainingLoop(BaseModel):
  init_weights_seed: int = Field(0, description="Seed for model weight initialization.")


class ManifoldConstrainedHyperConnections(BaseModel):
  """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""

  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
  sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")


class Optimizer(BaseModel):
  """Configuration for the optimizer and learning rate schedule."""

@@ -1743,6 +1751,7 @@ class MaxTextConfig(
    # Training, Optimization, and Fine-Tuning
    RematAndOffload,
    TrainingLoop,
    ManifoldConstrainedHyperConnections,
    Optimizer,
    AdamW,
    Muon,

src/MaxText/layers/initializers.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)

default_bias_init = jax.nn.initializers.constant(0.0)
default_scalar_init = jax.nn.initializers.constant(0.01)


def nd_dense_init(scale, mode, distribution):

src/MaxText/layers/mhc.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

from typing import Callable

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from flax import nnx

from MaxText.common_types import Array, Config, HyperConnectionType
from MaxText.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
from MaxText.layers.normalizations import RMSNorm


def get_functions(expansion_rate: int):
  """
  Creates functions to broadcast a single feature stream into multiple
  parallel paths (expand) and aggregate them back (reduce).
  """

  def expand(x: Array):
    # (batch, length, dim) -> (batch, length, streams, dim)
    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)

  def reduce(x: Array):
    # (batch, length, streams, dim) -> (batch, length, dim)
    return jnp.sum(x, axis=2)

  return expand, reduce
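
A toy shape check (illustrative only, not part of the commit) of the expand/reduce pair and the (batch, length, streams, dim) convention used throughout this file:

# Illustrative shapes: expand replicates a stream axis, reduce sums it back out.
expand, reduce = get_functions(expansion_rate=4)
x = jnp.ones((2, 8, 16))   # (batch=2, length=8, dim=16)
streams = expand(x)        # (2, 8, 4, 16)
merged = reduce(streams)   # (2, 8, 16); every entry is 4.0 here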


def sinkhorn(t, iters=20):
  """
  Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1).
  """
  # Use float32 precision for numerical stability during normalization.
  initial_dtype = t.dtype
  t = t.astype(jnp.float32)

  # Column-wise normalization (axis=-2): entries are positive and each column sums to 1.
  # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2, keepdims=True).
  t = jax.nn.softmax(t, axis=-2)

  def body_fun(i, val):
    # L1 normalization: val / sum(val), with the denominator clipped away from zero.
    # Normalize rows (axis -1).
    val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)
    # Normalize columns (axis -2).
    val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)
    return val

  # Use lax.fori_loop for an efficient, JIT-friendly loop.
  t = jax.lax.fori_loop(0, iters, body_fun, t)
  return t.astype(initial_dtype)
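
A quick test-style sketch (not part of the commit) of the property sinkhorn enforces: the trailing 2D block becomes approximately doubly stochastic.

# Illustrative check: row sums and column sums both approach 1 after the iterations.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
mixing = sinkhorn(logits, iters=20)
row_sums = jnp.sum(mixing, axis=-1)   # all close to 1.0
col_sums = jnp.sum(mixing, axis=-2)   # all close to 1.0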


class ManifoldConstrainedHyperConnections(nnx.Module):
  """Implements Manifold-Constrained Hyper-Connections (mHC).

  Reference: https://arxiv.org/pdf/2512.24880

  Args:
    config: Configuration object containing hyperparameters.
    dim: The feature dimensionality.
    mesh: The hardware mesh for sharding.
    rngs: NNX random number generator streams.
  """

  def __init__(
      self,
      config: Config,
      dim: int,
      mesh: Mesh,
      rngs: nnx.Rngs,
  ):
    self.config = config
    self.sinkhorn_iterations = config.sinkhorn_iterations
    self.k = config.mhc_expansion_rate
    self.dim = dim
    self.rngs = rngs
    self.mesh = mesh
    self.weight_dtype = self.config.weight_dtype

    # Norm layer
    self.mhc_norm = RMSNorm(
        num_features=self.k * self.dim,
        dtype=self.config.dtype,
        weight_dtype=self.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=self.rngs,
    )

    # Scalars
    self.res_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        sharding=(None,),
    )
    self.pre_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        sharding=(None,),
    )
    self.post_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        sharding=(None,),
    )

    # Weight matrices
    scale_init = nd_dense_init(1.0, "fan_in", "normal")
    in_axis = 0
    out_axis = 1
    weight_sharding_axis_name = ("activation_embed", None)
    self.res_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k * self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )
    self.pre_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )
    self.post_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )

    # Biases (sharding tuples match each parameter's rank)
    self.res_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
        sharding=(None, None),
    )
    self.pre_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        sharding=(None,),
    )
    self.post_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        sharding=(None,),
    )

  def res_mapping(self, x: Array):
    """Helper function for the residual mapping."""
    # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
    h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
    b, s, _ = h_res.shape
    h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
    intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
    output = sinkhorn(intermediate, self.sinkhorn_iterations)
    return output

  def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float):
    """Helper function for both the pre and post mappings."""
    # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision)
    intermediate = alpha_scale * h + beta[None, None, :]
    output = scale * jax.nn.sigmoid(intermediate)
    return output

  def __call__(
      self,
      branch_fn: Callable,
      x: Array,
      mhc_type: HyperConnectionType,
      **kwargs,
  ) -> Array:
    """Applies the manifold-constrained hyper connection around a callable branch function.

    Args:
      branch_fn: The function to be wrapped by the hyper-connection.
      x: Input tensor of shape `(batch, seq, expansion_rate, dim)`.
      mhc_type: The variant of the connection to apply.
      **kwargs: Additional context passed to the branch function.

    Returns:
      The processed tensor, with the same shape as `x`.
    """
    # x shape: [batch, seq, expansion_rate, emb]
    b, s, k, d = x.shape

    # 1. Flatten the streams and apply RMS normalization.
    norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

    # 2. Pre mapping: mix the k streams into a single branch input.
    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision)

    # 3. Attention or MLP branch.
    if mhc_type == HyperConnectionType.ATTENTION:
      layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_DENSE:
      layer_out = branch_fn(inputs=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_MOE:
      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
    else:
      raise ValueError(f"Unsupported type: {mhc_type}")

    # 4. Post mapping: scatter the branch output back across the k streams.
    post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision)

    # 5. Residual mapping; res_out has shape [batch, seq, expansion_rate, emb].
    res_mapping = self.res_mapping(norm_x)
    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision)
    return res_out + post_out
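
The decoder-side integration (the remaining file in this commit) is not shown above, so the sketch below is only a guess at the wiring: expand the residual stream once per block, wrap each sub-layer call in its own mHC instance, and reduce at the end. Every name in it (hidden_states, attention_op, mlp_op, attn_mhc, mlp_mhc) is illustrative, not a MaxText API.

# Hypothetical decoder-block wiring; all arguments are assumed to be supplied by the caller.
def apply_mhc_sublayers(hidden_states, attention_op, mlp_op, attn_mhc, mlp_mhc, expand, reduce):
  streams = expand(hidden_states)  # (b, s, d) -> (b, s, k, d)
  streams = attn_mhc(attention_op, streams, HyperConnectionType.ATTENTION)
  streams = mlp_mhc(mlp_op, streams, HyperConnectionType.MLP_DENSE)
  return reduce(streams)  # (b, s, k, d) -> (b, s, d)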
