Skip to content

Commit 45f5bf5

Browse files
committed
Fix: Added sharding constraints for MTP block (b/481469708)
1 parent ef90f2d commit 45f5bf5

2 files changed

Lines changed: 25 additions & 14 deletions

File tree

src/MaxText/layers/multi_token_prediction.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616

1717
from typing import Type
1818

19+
from flax import linen as nn
20+
from flax import nnx
1921
import jax
2022
import jax.numpy as jnp
2123
from jax.sharding import Mesh
2224

23-
from flax import linen as nn
24-
from flax import nnx
25-
25+
from MaxText import sharding
2626
from MaxText.common_types import Config, MODEL_MODE_TRAIN
27-
from MaxText.layers.linears import DenseGeneral
28-
from MaxText.layers.normalizations import RMSNorm
29-
from MaxText.layers.decoders import DecoderLayer
30-
from MaxText.layers import nnx_wrappers
3127
from MaxText.globals import EPS
28+
from MaxText.layers import nnx_wrappers
29+
from MaxText.layers.decoders import DecoderLayer
3230
from MaxText.layers.initializers import variable_to_logically_partitioned
31+
from MaxText.layers.linears import DenseGeneral
32+
from MaxText.layers.normalizations import RMSNorm
33+
3334
from maxtext.utils import max_utils
3435
from maxtext.utils import maxtext_utils
3536

@@ -84,24 +85,24 @@ def __init__(
8485
cfg = self.config
8586

8687
self.embedding_norm = RMSNorm(
87-
num_features=cfg.base_emb_dim,
88+
num_features=cfg.emb_dim,
8889
epsilon=cfg.normalization_layer_epsilon,
8990
dtype=cfg.dtype,
9091
weight_dtype=cfg.weight_dtype,
9192
kernel_axes=("norm",),
9293
rngs=rngs,
9394
)
9495
self.hidden_state_norm = RMSNorm(
95-
num_features=cfg.base_emb_dim,
96+
num_features=cfg.emb_dim,
9697
epsilon=cfg.normalization_layer_epsilon,
9798
dtype=cfg.dtype,
9899
weight_dtype=cfg.weight_dtype,
99100
kernel_axes=("norm",),
100101
rngs=rngs,
101102
)
102103
self.projection_layer = DenseGeneral(
103-
in_features_shape=2 * cfg.base_emb_dim,
104-
out_features_shape=cfg.base_emb_dim,
104+
in_features_shape=2 * cfg.emb_dim,
105+
out_features_shape=cfg.emb_dim,
105106
dtype=cfg.dtype,
106107
weight_dtype=cfg.weight_dtype,
107108
use_bias=False,
@@ -118,10 +119,11 @@ def __init__(
118119
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
119120

120121
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
122+
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
121123
self.transformer_layer.lazy_init(
122-
inputs=jnp.zeros((1, 1, cfg.base_emb_dim), dtype=cfg.dtype),
124+
inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
123125
decoder_segment_ids=None,
124-
decoder_positions=jnp.zeros((1, 1), dtype=jnp.int32),
126+
decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
125127
deterministic=True,
126128
model_mode=MODEL_MODE_TRAIN,
127129
)
@@ -149,6 +151,14 @@ def __call__(
149151
Returns:
150152
Processed hidden state. Shape [batch, seq_len, hidden_size].
151153
"""
154+
target_token_embedding = sharding.maybe_shard_with_logical(
155+
target_token_embedding,
156+
("activation_batch", "activation_length", "activation_embed"),
157+
self.mesh,
158+
self.config.shard_mode,
159+
self.config.logical_axis_rules,
160+
)
161+
152162
embedding_norm = self.embedding_norm(target_token_embedding)
153163
hidden_state_norm = self.hidden_state_norm(prev_hidden_state)
154164
concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1)

tests/unit/multi_token_prediction_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class MultiTokenPredictionBlockTest(unittest.TestCase):
199199
def setUp(self):
200200
super().setUp()
201201
# Conditionally set ici_fsdp_parallelism to match device count in decoupled mode
202+
num_devices = jax.device_count()
202203
extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
203204
self.cfg = pyconfig.initialize(
204205
[None, get_test_config_path()],
@@ -215,7 +216,7 @@ def setUp(self):
215216
self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
216217
data_rng, self.init_rng = jax.random.split(self.rng)
217218

218-
self.batch_size, self.seq_len, self.embed_dim = 2, 8, self.cfg.base_emb_dim
219+
self.batch_size, self.seq_len, self.embed_dim = num_devices, 8, self.cfg.base_emb_dim
219220
key1, key2, key3 = jax.random.split(data_rng, 3)
220221
self.main_hidden_state = jax.random.normal(key1, (self.batch_size, self.seq_len, self.embed_dim))
221222
self.input_ids = jax.random.randint(key2, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size)

0 commit comments

Comments (0)