Skip to content

Commit 377bf56

Browse files
Merge pull request #2826 from AI-Hypercomputer:chengnuojin-fix-llama2
PiperOrigin-RevId: 843918772
2 parents 79eecc9 + 7fc238e commit 377bf56

1 file changed

Lines changed: 1 addition & 5 deletions

File tree

src/MaxText/layers/normalizations.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,11 +77,7 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
 
     scale = jnp.asarray(scale, self.dtype)
     effective_scale = scale + self.scale_offset # Apply offset
-    # y: (B, S, E)
-    # effective_scale: (E,) -> (1, 1, E) -> (B, S, E)
-    effective_scale = jnp.expand_dims(effective_scale, axis=tuple(range(y.ndim - effective_scale.ndim)))
-    effective_scale = jnp.broadcast_to(effective_scale, y.shape, out_sharding=out_sharding)
-    return jnp.multiply(y, effective_scale)
+    return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding)
 
 
 def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):

0 commit comments

Comments (0)