Skip to content

Commit 2499b2d

Browse files
add fp32 layer norm
1 parent 08444fd commit 2499b2d

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

src/maxdiffusion/models/normalization_flax.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import jax
1818
import jax.numpy as jnp
1919
import flax.linen as nn
20+
from flax import nnx
2021

2122

2223
class AdaLayerNormContinuous(nn.Module):
@@ -147,3 +148,20 @@ def __call__(self, x, emb):
147148
else:
148149
raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.")
149150
return x, gate_msa
151+
152+
class FP32LayerNorm(nnx.Module):
  """LayerNorm that always computes in float32 for numerical stability.

  Inputs are upcast to float32 before normalization and the result is cast
  back to the input's original dtype, so callers see no dtype change while
  the mean/variance reduction avoids bf16/fp16 precision loss.
  """

  def __init__(self, rngs: nnx.Rngs, dim: int, eps: float = 1e-5, elementwise_affine: bool = True):
    """Create the wrapped float32 ``nnx.LayerNorm``.

    Args:
      rngs: nnx random-number generators used for parameter initialization.
      dim: number of features in the last input dimension.
      eps: small constant added to the variance to avoid division by zero.
        Defaults to 1e-5, matching ``torch.nn.LayerNorm``.
      elementwise_affine: when True, learn a per-feature scale and bias.
    """
    self.layer_norm = nnx.LayerNorm(
        rngs=rngs,
        num_features=dim,
        epsilon=eps,
        use_bias=elementwise_affine,
        use_scale=elementwise_affine,
        # Keep both the parameters and the computation in float32.
        param_dtype=jnp.float32,
        dtype=jnp.float32,
    )

  def __call__(self, inputs: jax.Array) -> jax.Array:
    """Apply layer norm in float32 and return the input's original dtype."""
    origin_dtype = inputs.dtype
    return self.layer_norm(inputs.astype(dtype=jnp.float32)).astype(dtype=origin_dtype)

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding
2424
from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
25+
from ..models.normalization_flax import FP32LayerNorm
2526

2627
class WanTransformerTest(unittest.TestCase):
2728
def setUp(self):
@@ -68,6 +69,21 @@ def test_nnx_timestep_embedding(self):
6869
dummy_output = layer(dummy_sample)
6970
assert dummy_output.shape == (1, 5120)
7071

72+
def test_fp32_layer_norm(self):
73+
key = jax.random.key(0)
74+
rngs = nnx.Rngs(key)
75+
batch_size = 1
76+
dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
77+
# expected same output shape with same dtype
78+
layer = FP32LayerNorm(
79+
rngs=rngs,
80+
dim=5120,
81+
eps=1e-6,
82+
elementwise_affine=False
83+
)
84+
dummy_output = layer(dummy_hidden_states)
85+
assert dummy_output.shape == dummy_hidden_states.shape
86+
7187
def test_wan_time_text_embedding(self):
7288
key = jax.random.key(0)
7389
rngs = nnx.Rngs(key)

0 commit comments

Comments
 (0)