Skip to content

Commit 08444fd

Browse files
add wan time text embedding layer.
1 parent 064fc5f commit 08444fd

4 files changed

Lines changed: 118 additions & 8 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,51 @@ def get_1d_rotary_pos_embed(
230230
out = jax.lax.complex(jnp.ones_like(freqs), freqs)
231231
return out
232232

233+
class NNXPixArtAlphaTextProjection(nnx.Module):
  """NNX port of ``PixArtAlphaTextProjection``: a two-layer MLP over caption embeddings.

  Projects text-encoder features of width ``in_features`` to ``out_features``
  (defaults to ``hidden_size``) via linear -> activation -> linear.

  Args:
    rngs: NNX RNG container used to initialize the linear layers.
    in_features: Width of the incoming caption embeddings.
    hidden_size: Width of the intermediate projection.
    out_features: Output width; falls back to ``hidden_size`` when ``None``.
    act_fn: Name of the activation resolved via ``get_activation``.
    dtype: Computation dtype of the linear layers.
    weights_dtype: Parameter storage dtype.
    precision: Optional matmul precision passed through to ``nnx.Linear``.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      in_features: int,
      hidden_size: int,
      out_features: int | None = None,  # annotation fixed: None is a valid (and the default) value
      act_fn: str = "gelu_tanh",
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision | None = None,  # annotation fixed: None is the default
  ):
    if out_features is None:
      out_features = hidden_size

    self.linear_1 = nnx.Linear(
        rngs=rngs,
        in_features=in_features,
        out_features=hidden_size,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        # Partition spec: input axis "embed", hidden axis "mlp".
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )
    self.act_1 = get_activation(act_fn)

    self.linear_2 = nnx.Linear(
        rngs=rngs,
        in_features=hidden_size,
        out_features=out_features,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        # Mirror of linear_1's spec: hidden axis "mlp" back to "embed".
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

  def __call__(self, caption):
    """Apply linear_1 -> activation -> linear_2 to ``caption`` and return the result."""
    hidden_states = self.linear_1(caption)
    hidden_states = self.act_1(hidden_states)
    hidden_states = self.linear_2(hidden_states)
    return hidden_states
233278

234279
class PixArtAlphaTextProjection(nn.Module):
235280
"""

src/maxdiffusion/models/modeling_flax_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242

4343

4444
logger = logging.get_logger(__name__)
45-
46-
_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish}
45+
# gelu and gelu_tanh both use approximate=True by default
46+
_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "gelu_tanh" : jax.nn.gelu, "mish": jax.nn.mish}
4747

4848
def get_activation(name: str):
4949
func = _ACTIVATIONS.get(name)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919
import jax.numpy as jnp
2020
from flax import nnx
2121
from .... import common_types, max_logging
22-
from ...modeling_flax_utils import FlaxModelMixin
22+
from ...modeling_flax_utils import FlaxModelMixin, get_activation
2323
from ....configuration_utils import ConfigMixin, register_to_config
24-
from ...embeddings_flax import get_1d_rotary_pos_embed, NNXFlaxTimesteps, NNXTimestepEmbedding
24+
from ...embeddings_flax import (
25+
get_1d_rotary_pos_embed,
26+
NNXFlaxTimesteps,
27+
NNXTimestepEmbedding,
28+
NNXPixArtAlphaTextProjection
29+
)
2530

2631
BlockSizes = common_types.BlockSizes
2732

@@ -101,6 +106,23 @@ def __init__(
101106
rngs=rngs, in_channels=time_freq_dim, time_embed_dim=dim,
102107
dtype=dtype, weights_dtype=weights_dtype, precision=precision
103108
)
109+
self.act_fn = get_activation("silu")
110+
self.time_proj = nnx.Linear(
111+
rngs=rngs,
112+
in_features=dim,
113+
out_features=time_proj_dim,
114+
dtype=dtype,
115+
param_dtype=weights_dtype,
116+
precision=precision,
117+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)),
118+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
119+
)
120+
self.text_embedder = NNXPixArtAlphaTextProjection(
121+
rngs=rngs,
122+
in_features=text_embed_dim,
123+
hidden_size=dim,
124+
act_fn="gelu_tanh",
125+
)
104126

105127
def __call__(
106128
self,
@@ -110,7 +132,13 @@ def __call__(
110132
):
111133
timestep = self.timesteps_proj(timestep)
112134
temb = self.time_embedder(timestep)
113-
breakpoint()
135+
136+
timestep_proj = self.time_proj(self.act_fn(temb))
137+
138+
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
139+
if encoder_hidden_states_image is not None:
140+
raise NotImplementedError("currently img2vid is not supported")
141+
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
114142

115143

116144

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from absl.testing import absltest
2121
from flax import nnx
2222

23-
from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed
24-
from ..models.embeddings_flax import NNXTimestepEmbedding
23+
from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding
24+
from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
2525

2626
class WanTransformerTest(unittest.TestCase):
2727
def setUp(self):
@@ -42,7 +42,19 @@ def test_rotary_pos_embed(self):
4242
)
4343
dummy_output = wan_rot_embed(dummy_hidden_states)
4444
assert dummy_output.shape == (1, 1, 75600, 64)
45-
45+
46+
def test_nnx_pixart_alpha_text_projection(self):
  """Text projection maps (B, S, in_features) -> (B, S, hidden_size)."""
  key = jax.random.key(0)
  rngs = nnx.Rngs(key)
  dummy_caption = jnp.ones((1, 512, 4096))
  layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
  dummy_output = layer(dummy_caption)
  # Bug fix: the original line was a bare comparison (missing `assert`),
  # so the shape check was silently a no-op.
  assert dummy_output.shape == (1, 512, 5120)
4658
def test_nnx_timestep_embedding(self):
4759
key = jax.random.key(0)
4860
rngs = nnx.Rngs(key)
@@ -56,5 +68,30 @@ def test_nnx_timestep_embedding(self):
5668
dummy_output = layer(dummy_sample)
5769
assert dummy_output.shape == (1, 5120)
5870

71+
def test_wan_time_text_embedding(self):
  """Smoke-test WanTimeTextImageEmbedding output shapes for text-to-video inputs."""
  rngs = nnx.Rngs(jax.random.key(0))
  batch_size = 1
  dim = 5120
  time_freq_dim = 256
  time_proj_dim = 30720
  text_embed_dim = 4096

  layer = WanTimeTextImageEmbedding(
      rngs=rngs,
      dim=dim,
      time_freq_dim=time_freq_dim,
      time_proj_dim=time_proj_dim,
      text_embed_dim=text_embed_dim,
  )

  dummy_timestep = jnp.ones(batch_size)
  dummy_encoder_hidden_states = jnp.ones((batch_size, time_freq_dim * 2, text_embed_dim))

  temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
      dummy_timestep, dummy_encoder_hidden_states
  )
  assert temb.shape == (batch_size, dim)
  assert timestep_proj.shape == (batch_size, time_proj_dim)
  assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
5996
if __name__ == "__main__":
6097
absltest.main()

0 commit comments

Comments
 (0)