@@ -16,20 +16,21 @@
 
 from typing import Type
 
+from flax import linen as nn
+from flax import nnx
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 
-from flax import linen as nn
-from flax import nnx
-
+from MaxText import sharding
 from MaxText.common_types import Config, MODEL_MODE_TRAIN
-from MaxText.layers.linears import DenseGeneral
-from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.decoders import DecoderLayer
-from MaxText.layers import nnx_wrappers
 from MaxText.globals import EPS
+from MaxText.layers import nnx_wrappers
+from MaxText.layers.decoders import DecoderLayer
 from MaxText.layers.initializers import variable_to_logically_partitioned
+from MaxText.layers.linears import DenseGeneral
+from MaxText.layers.normalizations import RMSNorm
+
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
 
@@ -84,24 +85,24 @@ def __init__(
     cfg = self.config
 
     self.embedding_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         kernel_axes=("norm",),
         rngs=rngs,
     )
     self.hidden_state_norm = RMSNorm(
-        num_features=cfg.base_emb_dim,
+        num_features=cfg.emb_dim,
         epsilon=cfg.normalization_layer_epsilon,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         kernel_axes=("norm",),
         rngs=rngs,
     )
     self.projection_layer = DenseGeneral(
-        in_features_shape=2 * cfg.base_emb_dim,
-        out_features_shape=cfg.base_emb_dim,
+        in_features_shape=2 * cfg.emb_dim,
+        out_features_shape=cfg.emb_dim,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         use_bias=False,
@@ -118,10 +119,11 @@ def __init__(
     self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
 
     # ToNNX requires explicit initialization with sample inputs for proper parameter setup.
+    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
     self.transformer_layer.lazy_init(
-        inputs=jnp.zeros((1, 1, cfg.base_emb_dim), dtype=cfg.dtype),
+        inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
         decoder_segment_ids=None,
-        decoder_positions=jnp.zeros((1, 1), dtype=jnp.int32),
+        decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
         deterministic=True,
         model_mode=MODEL_MODE_TRAIN,
     )
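
The lazy initialization above follows the general Flax Linen-to-NNX bridge idea: a wrapped Linen module has no materialized parameters until it is traced once with representative inputs. Below is a minimal sketch of that upstream pattern using flax.nnx.bridge; the toy module Tiny and its shapes are hypothetical stand-ins, not MaxText's nnx_wrappers or decoder layer.

# Minimal sketch of the Linen -> NNX bridge pattern assumed above: the wrapped
# module only creates its parameters after one pass over sample inputs.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge


class Tiny(nn.Module):  # hypothetical toy module, stands in for the decoder layer
  features: int = 8

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


wrapped = bridge.ToNNX(Tiny(), rngs=nnx.Rngs(0))
# The sample input shapes drive shape inference for the created parameters,
# which is why the diff passes real (batch, seq_len, emb_dim) dimensions.
bridge.lazy_init(wrapped, jnp.zeros((2, 4, 8)))
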
@@ -149,6 +151,14 @@ def __call__(
     Returns:
       Processed hidden state. Shape [batch, seq_len, hidden_size].
     """
+    target_token_embedding = sharding.maybe_shard_with_logical(
+        target_token_embedding,
+        ("activation_batch", "activation_length", "activation_embed"),
+        self.mesh,
+        self.config.shard_mode,
+        self.config.logical_axis_rules,
+    )
+
     embedding_norm = self.embedding_norm(target_token_embedding)
     hidden_state_norm = self.hidden_state_norm(prev_hidden_state)
     concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1)
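
The maybe_shard_with_logical call added in __call__ constrains the activation layout by logical axis names. The sketch below shows roughly what such a logical-to-physical sharding constraint amounts to in plain JAX/Flax; the helper function, the 1-D data-parallel mesh, and the rule set here are assumptions for illustration, not MaxText's actual helper, shard_mode handling, or logical_axis_rules.

# Rough sketch (assumed, not MaxText's helper): translate logical axis names
# into a mesh sharding via logical-axis rules, then constrain the array.
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec


def shard_with_logical(x, logical_axes, mesh, rules):
  """Constrain x to the NamedSharding implied by logical_axes under rules."""
  named_sharding = nn.logical_to_mesh_sharding(PartitionSpec(*logical_axes), mesh, rules)
  return jax.lax.with_sharding_constraint(x, named_sharding)


# Hypothetical 1-D mesh and rules: shard only the batch axis across devices.
mesh = Mesh(np.array(jax.devices()), ("data",))
rules = (("activation_batch", "data"), ("activation_length", None), ("activation_embed", None))
x = jnp.zeros((4 * jax.device_count(), 128, 512))
x = jax.jit(
    lambda a: shard_with_logical(a, ("activation_batch", "activation_length", "activation_embed"), mesh, rules)
)(x)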