
Commit 3c56dd3

Merge pull request #3115 from AI-Hypercomputer:mhc_integration
PiperOrigin-RevId: 869924414
2 parents 98fb5cf + 0147f4b commit 3c56dd3

8 files changed

Lines changed: 172 additions & 67 deletions


src/MaxText/layers/decoders.py

Lines changed: 11 additions & 1 deletion
@@ -35,6 +35,7 @@
 from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
+from MaxText.layers import mhc
 from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
@@ -731,6 +732,11 @@ def __call__(
         audio_masks,
     )

+    mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
+      y = mhc_expand(y)
+
     policy = self.get_remat_policy()
     RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy)
     # scan does not support kwargs in layer call, passing broadcast_args as positional arg
@@ -927,7 +933,11 @@ def __call__(
     assert isinstance(y, jax.Array)

     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
-    hidden_state = y
+    if cfg.mhc_expansion_rate > 1:
+      # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
+      hidden_state = mhc_reduce(y)
+    else:
+      hidden_state = y

     # When initializing with vLLM RPA attention, we need to run the output head to
     # initialize any parameters associated with it.
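
Aside: get_functions returns an expand/reduce pair that converts between a single residual stream and mhc_expansion_rate parallel streams, as the shape comments above indicate. A minimal sketch of the round trip, assuming only those shapes (not a verbatim copy of mhc.py):

import jax.numpy as jnp

def expand(x, expansion_rate):
  # (batch, length, emb_dim) -> (batch, length, expansion_rate, emb_dim)
  return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)

def reduce(x):
  # (batch, length, expansion_rate, emb_dim) -> (batch, length, emb_dim)
  return jnp.sum(x, axis=2)

y = jnp.ones((2, 8, 16))   # (batch=2, length=8, emb_dim=16)
streams = expand(y, 4)     # (2, 8, 4, 16): input copied into 4 streams
out = reduce(streams)      # (2, 8, 16): streams summed back together
assert out.shape == y.shape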

src/MaxText/layers/deepseek.py

Lines changed: 64 additions & 41 deletions
@@ -23,14 +23,15 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from MaxText.common_types import Config
-from MaxText.common_types import MODEL_MODE_PREFILL
+from MaxText.common_types import MODEL_MODE_PREFILL, HyperConnectionType
 from MaxText.layers import attention_mla
 from MaxText.layers import deepseek_batchsplit
 from MaxText.layers import initializers
 from MaxText.layers import linears
 from MaxText.layers import moe
 from MaxText.layers import nnx_wrappers
 from MaxText.layers import quantizations
+from MaxText.layers import mhc
 from MaxText.layers.linears import Dropout
 from MaxText.layers.normalizations import RMSNorm
 from MaxText.sharding import create_sharding
@@ -64,6 +65,7 @@ def __init__(
     self.mesh = mesh
     self.quant = quant
     self.rngs = rngs
+    self.is_mhc_enabled = config.mhc_expansion_rate > 1

     batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
     self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
@@ -122,6 +124,9 @@ def __init__(
     )

     self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
+    if self.is_mhc_enabled:
+      self.mhc_attention = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
+      self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)

   def mlp_op(self, x, deterministic, *args, **kwargs):
     """Executes the MLP operation. To be implemented by subclasses."""
@@ -172,31 +177,17 @@ def attention_op(

   @property
   def logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_embed",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_embed",
-    )
+    """Generate logical names for activations generally."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_embed"]
+    return axis_names

   @property
   def mlp_logical_axis_names(self):
-    if self.model_mode == MODEL_MODE_PREFILL:
-      return (
-          "activation_batch",
-          "prefill_activation_norm_length",
-          "activation_mlp",
-      )
-    return (
-        "activation_batch",
-        "activation_norm_length",
-        "activation_mlp",
-    )
+    """Generate logical names for activations in MLP."""
+    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
+    axis_names = ["activation_batch", length_name, "activation_mlp"]
+    return axis_names

   def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None):
     """postprocessing."""
@@ -231,18 +222,33 @@ def self_attention_with_norm_op(
       slot: None | int = None,
   ):
     """self-attention with normalization"""
-    lnx = self.pre_attention_norm_op(inputs)
-
-    attention_lnx = self.attention_op(
-        lnx,
-        decoder_segment_ids,
-        decoder_positions,
-        deterministic,
-        previous_chunk,
-        page_state,
-        slot,
-    )
-    intermediate_inputs = inputs + attention_lnx
+    if self.is_mhc_enabled:
+      intermediate_inputs, _ = self.mhc_attention(
+          self.pre_attention_norm_op,
+          self.self_attention,
+          x=inputs,
+          mhc_type=HyperConnectionType.ATTENTION,
+          decoder_segment_ids=decoder_segment_ids,
+          inputs_positions=decoder_positions,
+          deterministic=deterministic,
+          model_mode=self.model_mode,
+          out_sharding=self.out_sharding,
+          previous_chunk=previous_chunk,
+          page_state=page_state,
+          slot=slot,
+      )
+    else:
+      lnx = self.pre_attention_norm_op(inputs)
+      attention_lnx = self.attention_op(
+          lnx,
+          decoder_segment_ids,
+          decoder_positions,
+          deterministic,
+          previous_chunk,
+          page_state,
+          slot,
+      )
+      intermediate_inputs = inputs + attention_lnx
     # Normalization
     hidden_states = self.post_attention_norm_op(intermediate_inputs)
     return hidden_states, intermediate_inputs
@@ -308,9 +314,17 @@ def __call__(
         slot,
     )

-    mlp_lnx = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, _ = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.mlp,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_DENSE,
+          deterministic=deterministic,
+      )
+    else:
+      mlp_lnx = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)

     return self.post_process(layer_output, None, None, kv_cache)
@@ -394,9 +408,18 @@ def __call__(
         slot,
     )

-    mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
-
-    layer_output = mlp_lnx + intermediate_inputs
+    if self.is_mhc_enabled:
+      layer_output, metadata = self.mhc_mlp(
+          self.post_attention_norm_op,
+          self.DeepSeekMoeBlock_0,
+          x=intermediate_inputs,
+          mhc_type=HyperConnectionType.MLP_MOE,
+      )
+      load_balance_loss = metadata["load_balance_loss"]
+      moe_bias_updates = metadata["moe_bias_updates"]
+    else:
+      mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
+      layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)

     return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)
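
The structural change in this file: the standard path normalizes, calls attention or the MLP, and adds the residual itself, while the mHC path hands both the norm and the branch to the hyper-connection, which owns the residual across the k streams. A schematic comparison, with hypothetical stand-in function names (the real classes live in deepseek.py and mhc.py):

def standard_block(inputs, norm_fn, branch_fn):
  # Classic pre-norm residual: the caller normalizes, runs the branch, and
  # adds the skip connection explicitly.
  return inputs + branch_fn(norm_fn(inputs))

def mhc_block(inputs, mhc_layer, norm_fn, branch_fn):
  # mHC path: the hyper-connection receives both callables, mixes the k
  # streams into a single branch input, applies norm_fn then branch_fn, and
  # scatters the branch output back across the streams, so the caller never
  # writes `inputs + ...` itself.
  layer_output, metadata = mhc_layer(norm_fn, branch_fn, x=inputs)
  return layer_output, metadata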

src/MaxText/layers/mhc.py

Lines changed: 35 additions & 17 deletions
@@ -34,11 +34,11 @@ def get_functions(expansion_rate: int):

   def expand(x: Array):
     # (batch, length, dim) -> (batch, length, streams, dim)
-    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)
+    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype)

   def reduce(x: Array):
     # (batch, length, streams, dim) -> (batch, length, dim)
-    return jnp.sum(x, axis=2)
+    return jnp.sum(x, axis=2, dtype=x.dtype)

   return expand, reduce

@@ -93,7 +93,9 @@ def __init__(
     self.dim = dim
     self.rngs = rngs
     self.mesh = mesh
+    self.dtype = self.config.dtype
     self.weight_dtype = self.config.weight_dtype
+    self.matmul_precision = jax.lax.Precision(self.config.matmul_precision)

     # Norm layer
     self.mhc_norm = RMSNorm(
@@ -162,33 +164,42 @@ def __init__(
     )
     self.pre_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )
     self.post_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )

   def res_mapping(self, x: Array):
     """Helper function for residual mapping."""
+    # In MaxText, we match weight precision to activations before Matmul
+    res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
+    res_beta = jnp.asarray(self.res_beta[...], self.dtype)
+    res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
-    h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
+    h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
     b, s, _ = h_res.shape
     h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
-    intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
+    intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
     output = sinkhorn(intermediate, self.sinkhorn_iterations)
     return output

   def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
     """Helper function for both pre and post mappings."""
+    # In MaxText, we match weight precision to activations before Matmul
+    alpha = jnp.asarray(alpha, self.dtype)
+    beta = jnp.asarray(beta, self.dtype)
+    alpha_scale = jnp.asarray(alpha_scale, self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
-    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision)
+    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision)
     intermediate = alpha_scale * h + beta[None, None, :]
     output = scale * jax.nn.sigmoid(intermediate)
     return output

   def __call__(
       self,
+      norm_fn: Callable,
       branch_fn: Callable,
       x: Array,
       mhc_type: HyperConnectionType,
@@ -197,6 +208,7 @@ def __call__(
     """Applying manifold-constrained hyper connection based on callable function.

     Args:
+      norm_fn: The pre-normalization function to be applied.
       branch_fn: The function to be wrapped by the hyper-connection.
       x: Input tensor of shape `(batch..., dim)`.
       mhc_type: The variant of the connection to apply.
@@ -212,24 +224,30 @@ def __call__(
     norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

     # 2. Pre mapping
-    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
-    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision)
+    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale[...], self.pre_alpha[...], self.pre_beta[...], 1.0)
+    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision)
+
+    # 3. Pre-norm
+    layer_input = norm_fn(layer_input)

-    # 3. Attention or MLP
+    # 4. Attention or MLP
+    metadata = {}
     if mhc_type == HyperConnectionType.ATTENTION:
       layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_DENSE:
       layer_out = branch_fn(inputs=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_MOE:
-      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
+      layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs)
+      metadata["load_balance_loss"] = load_balance_loss
+      metadata["moe_bias_updates"] = moe_bias_updates
     else:
       raise ValueError(f"Unsupported type: {mhc_type}")

-    # 4. Post mapping
-    post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
-    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision)
+    # 5. Post mapping
+    post_mapping = self.mapping(norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0)
+    post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.matmul_precision)

-    # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
+    # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
     res_mapping = self.res_mapping(norm_x)
-    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision)
-    return res_out + post_out
+    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision)
+    return res_out + post_out, metadata
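
Note: res_mapping relies on a sinkhorn helper that is defined elsewhere in mhc.py and not shown in this diff. For orientation, a standard Sinkhorn-Knopp iteration looks roughly like the sketch below; this is an assumption about the implementation, not the MaxText code. It alternately normalizes rows and columns of a positive matrix so the trailing (k, k) stream-mixing matrix approaches a doubly stochastic one:

import jax
import jax.numpy as jnp

def sinkhorn_sketch(logits, num_iterations):
  # Exponentiate to get a positive matrix (assumed), then alternate
  # row/column normalization over the trailing (k, k) axes.
  m = jnp.exp(logits)
  for _ in range(num_iterations):
    m = m / jnp.sum(m, axis=-1, keepdims=True)  # rows sum to 1
    m = m / jnp.sum(m, axis=-2, keepdims=True)  # columns sum to 1
  return m

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 4, 4))  # (b, s, k, k)
mix = sinkhorn_sketch(logits, 20)  # matches sinkhorn_iterations: 20
print(jnp.sum(mix, axis=-1))       # row sums ~= 1.0 after 20 iterations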

src/MaxText/train.py

Lines changed: 17 additions & 2 deletions
@@ -197,8 +197,23 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
-    nested_key = ("intermediates", "decoder", "layers", "moe_lb_loss")
-    total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+    # Note: the key is affected by the model implementation
+    possible_keys = [
+        ("intermediates", "decoder", "layers", "moe_lb_loss"),
+        ("intermediates", "decoder", "moe_layers", "moe_lb_loss"),
+    ]
+
+    total_moe_lb_loss = 0.0
+    found_loss = False
+    for nested_key in possible_keys:
+      total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+      if total_moe_lb_loss != 0.0:
+        found_loss = True
+        break
+
+    if not found_loss:
+      max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")
+
     moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
     loss += moe_lb_loss
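The fallback loop above assumes get_nested_value walks a tuple of keys through the intermediates pytree and returns the default when any level is missing. A sketch of that assumed behavior (illustrative only; the real helper lives in maxtext_utils):

def get_nested_value(tree, nested_key, default):
  # Follow the key path through nested dicts; bail out with `default` as
  # soon as a level is missing.
  node = tree
  for key in nested_key:
    if not isinstance(node, dict) or key not in node:
      return default
    node = node[key]
  return node

outputs = {"intermediates": {"decoder": {"moe_layers": {"moe_lb_loss": 0.01}}}}
# The first candidate key path misses, the second one hits:
assert get_nested_value(outputs, ("intermediates", "decoder", "layers", "moe_lb_loss"), 0.0) == 0.0
assert get_nested_value(outputs, ("intermediates", "decoder", "moe_layers", "moe_lb_loss"), 0.0) == 0.01
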
src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -1090,6 +1090,6 @@ force_q_layout: false

 ################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
 # The number of parallel streams in Hyper Connection.
-mhc_expansion_rate: 0
+mhc_expansion_rate: 1
 # The number of iterations for the Sinkhorn-Knopp algorithm.
 sinkhorn_iterations: 20
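
The default moves from 0 to 1 because the code treats a single stream as "mHC disabled" (see is_mhc_enabled = config.mhc_expansion_rate > 1 in deepseek.py), and the Pydantic field in the next file now requires a positive integer. A tiny standalone illustration of those semantics (hypothetical helper, not MaxText code):

def is_mhc_enabled(mhc_expansion_rate: int) -> bool:
  # Rate 1 keeps the single residual stream and skips every mHC code path;
  # rates of 2 or more turn the feature on.
  return mhc_expansion_rate > 1

assert not is_mhc_enabled(1)  # new default: disabled
assert is_mhc_enabled(4)      # four parallel streams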

src/maxtext/configs/types.py

Lines changed: 2 additions & 2 deletions
@@ -248,7 +248,7 @@ class ProfilerType(str, Enum):
     "llama4-17b-16e",
     "llama4-17b-128e",
     "olmo3-7b",
-    'olmo3-7b-pt',
+    "olmo3-7b-pt",
    "olmo3-32b",
 ]

@@ -1085,7 +1085,7 @@ class TrainingLoop(BaseModel):
 class ManifoldConstrainedHyperConnections(BaseModel):
   """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""

-  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
+  mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
   sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")

