@@ -34,11 +34,11 @@ def get_functions(expansion_rate: int):

   def expand(x: Array):
     # (batch, length, dim) -> (batch, length, streams, dim)
-    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)
+    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype)

   def reduce(x: Array):
     # (batch, length, streams, dim) -> (batch, length, dim)
-    return jnp.sum(x, axis=2)
+    return jnp.sum(x, axis=2, dtype=x.dtype)

   return expand, reduce

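As a minimal illustration of how the patched expand/reduce pair behaves (shapes and dtypes here are made up; `get_functions` is assumed to be in scope from the hunk above), the stream dimension is inserted at axis 2 and the round trip preserves the input dtype:

import jax.numpy as jnp

expand, reduce = get_functions(expansion_rate=4)
x = jnp.ones((2, 8, 16), dtype=jnp.bfloat16)  # (batch, length, dim)
streams = expand(x)                           # (2, 8, 4, 16), cast back to x.dtype
merged = reduce(streams)                      # (2, 8, 16), summed with dtype=x.dtype
assert streams.shape == (2, 8, 4, 16) and merged.dtype == x.dtype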
@@ -93,7 +93,9 @@ def __init__(
     self.dim = dim
     self.rngs = rngs
     self.mesh = mesh
+    self.dtype = self.config.dtype
     self.weight_dtype = self.config.weight_dtype
+    self.matmul_precision = jax.lax.Precision(self.config.matmul_precision)

     # Norm layer
     self.mhc_norm = RMSNorm(
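For context, the activation dtype and the parameter storage dtype typically differ under mixed precision, and `jax.lax.Precision` accepts the config string directly. A small sketch with illustrative values (not taken from this diff) of the pattern the methods below follow, casting weights to the activation dtype before the matmul:

import jax
import jax.numpy as jnp

weights = jnp.zeros((16, 4), dtype=jnp.float32)   # stored in weight_dtype
acts = jnp.ones((2, 16), dtype=jnp.bfloat16)      # computed in dtype
precision = jax.lax.Precision("highest")          # config string -> Precision enum member
out = jnp.einsum("bm,mk->bk", acts, jnp.asarray(weights, acts.dtype), precision=precision)
assert out.dtype == jnp.bfloat16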
@@ -162,33 +164,42 @@ def __init__(
     )
     self.pre_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )
     self.post_beta = nnx.Param(
         default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
-        sharding=(None, None),
+        sharding=(None,),
     )

   def res_mapping(self, x: Array):
     """Helper function for residual mapping."""
+    # In MaxText, we match weight precision to activations before Matmul
+    res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
+    res_beta = jnp.asarray(self.res_beta[...], self.dtype)
+    res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
-    h_res = jnp.einsum("bsm,mn->bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
+    h_res = jnp.einsum("bsm,mn->bsn", x, res_alpha, precision=self.matmul_precision)
     b, s, _ = h_res.shape
     h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
-    intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
+    intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
     output = sinkhorn(intermediate, self.sinkhorn_iterations)
     return output

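The `sinkhorn` helper is defined outside this hunk. As a rough, illustrative sketch only (the repo's actual implementation may differ), a Sinkhorn-style normalization exponentiates the intermediate and then alternates row and column normalizations over the last two axes to push each (k, k) block toward a doubly stochastic matrix:

import jax.numpy as jnp

def sinkhorn_sketch(logits, num_iters, eps=1e-6):
  """Illustrative only: alternate row/column normalization of the last two axes."""
  m = jnp.exp(logits)                                   # make entries positive
  for _ in range(num_iters):
    m = m / (jnp.sum(m, axis=-1, keepdims=True) + eps)  # row-normalize
    m = m / (jnp.sum(m, axis=-2, keepdims=True) + eps)  # column-normalize
  return m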
   def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
     """Helper function for both pre and post mappings."""
+    # In MaxText, we match weight precision to activations before Matmul
+    alpha = jnp.asarray(alpha, self.dtype)
+    beta = jnp.asarray(beta, self.dtype)
+    alpha_scale = jnp.asarray(alpha_scale, self.dtype)
     # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
-    h = jnp.einsum("bsm,mk->bsk", x, alpha, precision=self.config.matmul_precision)
+    h = jnp.einsum("bsm,mk->bsk", x, alpha, precision=self.matmul_precision)
     intermediate = alpha_scale * h + beta[None, None, :]
     output = scale * jax.nn.sigmoid(intermediate)
     return output
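The `scale` argument bounds the gate values: sigmoid keeps them in (0, 1), so the pre mapping (scale=1.0) produces gates below 1 while the post mapping (scale=2.0) can reach up to 2. A toy check with made-up inputs:

import jax
import jax.numpy as jnp

gates = 2.0 * jax.nn.sigmoid(jnp.array([-10.0, 0.0, 10.0]))
# ~[0.00, 1.00, 2.00]: post-mapping gates lie in (0, 2); with scale=1.0 they stay in (0, 1)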

   def __call__(
     self,
+    norm_fn: Callable,
     branch_fn: Callable,
     x: Array,
     mhc_type: HyperConnectionType,
@@ -197,6 +208,7 @@ def __call__(
     """Applying manifold-constrained hyper connection based on callable function.

     Args:
+      norm_fn: The pre-normalization function to be applied.
       branch_fn: The function to be wrapped by the hyper-connection.
       x: Input tensor of shape `(batch..., dim)`.
       mhc_type: The variant of the connection to apply.
@@ -212,24 +224,30 @@ def __call__(
     norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

     # 2. Pre mapping
-    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
-    layer_input = jnp.einsum("bskd,bsk->bsd", x, pre_mapping, precision=self.config.matmul_precision)
+    pre_mapping = self.mapping(norm_x, self.pre_alpha_scale[...], self.pre_alpha[...], self.pre_beta[...], 1.0)
+    layer_input = jnp.einsum("bskd,bsk->bsd", x, pre_mapping, precision=self.matmul_precision)
+
+    # 3. Pre-norm
+    layer_input = norm_fn(layer_input)

-    # 3. Attention or MLP
+    # 4. Attention or MLP
+    metadata = {}
     if mhc_type == HyperConnectionType.ATTENTION:
       layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_DENSE:
       layer_out = branch_fn(inputs=layer_input, **kwargs)
     elif mhc_type == HyperConnectionType.MLP_MOE:
-      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
+      layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs)
+      metadata["load_balance_loss"] = load_balance_loss
+      metadata["moe_bias_updates"] = moe_bias_updates
     else:
       raise ValueError(f"Unsupported type: {mhc_type}")

-    # 4. Post mapping
-    post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
-    post_out = jnp.einsum("bsd,bsk->bskd", layer_out, post_mapping, precision=self.config.matmul_precision)
+    # 5. Post mapping
+    post_mapping = self.mapping(norm_x, self.post_alpha_scale[...], self.post_alpha[...], self.post_beta[...], 2.0)
+    post_out = jnp.einsum("bsd,bsk->bskd", layer_out, post_mapping, precision=self.matmul_precision)

-    # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
+    # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
     res_mapping = self.res_mapping(norm_x)
-    res_out = jnp.einsum("bskd,bskm->bsmd", x, res_mapping, precision=self.config.matmul_precision)
-    return res_out + post_out
+    res_out = jnp.einsum("bskd,bskm->bsmd", x, res_mapping, precision=self.matmul_precision)
+    return res_out + post_out, metadata
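After this change, call sites must supply the pre-normalization function and unpack the `(output, metadata)` pair. A hypothetical caller might look like the following; the layer and function names are placeholders, not taken from this diff:

out, mhc_metadata = mhc_layer(
    norm_fn=pre_attention_norm,        # placeholder pre-norm callable
    branch_fn=attention_layer,         # placeholder attention callable
    x=hidden_streams,                  # (batch, seq, expansion_rate, emb)
    mhc_type=HyperConnectionType.ATTENTION,
)
load_balance_loss = mhc_metadata.get("load_balance_loss")  # populated only for MLP_MOE branches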