Skip to content

Commit 064fc5f

Browse files
Add NNX classes for timestep embeddings and timesteps.
1 parent d449d1f commit 064fc5f

5 files changed

Lines changed: 229 additions & 38 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15-
15+
from typing import Optional
1616
import flax.linen as nn
17+
from flax import nnx
1718
import jax.numpy as jnp
1819
from typing import List, Union
1920
import jax
21+
from .modeling_flax_utils import get_activation
2022

2123

2224
def get_sinusoidal_embeddings(
@@ -56,6 +58,86 @@ def get_sinusoidal_embeddings(
5658
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
5759
return signal
5860

61+
class NNXTimestepEmbedding(nnx.Module):
  r"""
  Time step Embedding Module. Learns embeddings for input time steps.

  Two linear projections with an activation in between, optionally preceded by a
  conditioning projection and followed by a post-activation.

  Args:
      rngs (`nnx.Rngs`): RNG state used to initialize the linear layers.
      in_channels (`int`):
          Dimension of the incoming (e.g. sinusoidal) timestep features.
      time_embed_dim (`int`, *optional*, defaults to `32`):
          Time step embedding dimension (width of the first projection).
      act_fn (`str`, *optional*, defaults to `"silu"`):
          Name of the activation applied between the two projections.
      out_dim (`int`, *optional*):
          Output dimension of the second projection; defaults to `time_embed_dim`.
      post_act_fn (`str`, *optional*):
          Name of an activation applied after the second projection, if any.
      cond_proj_dim (`int`, *optional*):
          If given, dimension of an extra condition tensor that is projected to
          `in_channels` and added to the sample before the first projection.
      sample_proj_bias (`bool`, *optional*, defaults to `True`):
          Whether the two sample projections use a bias term.
      dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
          Computation `dtype`.
      weights_dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
          Parameters `dtype`.
      precision (`jax.lax.Precision`, *optional*):
          Matmul precision for the linear layers.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      in_channels: int,
      time_embed_dim: int = 32,
      act_fn: str = "silu",
      out_dim: Optional[int] = None,
      post_act_fn: Optional[str] = None,
      cond_proj_dim: Optional[int] = None,
      sample_proj_bias: bool = True,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
  ):
    self.linear_1 = nnx.Linear(
        rngs=rngs,
        in_features=in_channels,
        out_features=time_embed_dim,
        use_bias=sample_proj_bias,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )

    if cond_proj_dim is not None:
      # Project the optional condition to `in_channels` so it can be added to the
      # sample before `linear_1`. Bug fix: the original constructed nnx.Linear
      # with only `rngs`, which raises a TypeError (missing in_features /
      # out_features) whenever cond_proj_dim is set. No bias, matching the
      # conventional TimestepEmbedding conditioning projection.
      self.cond_proj = nnx.Linear(
          rngs=rngs,
          in_features=cond_proj_dim,
          out_features=in_channels,
          use_bias=False,
          dtype=dtype,
          param_dtype=weights_dtype,
          precision=precision,
      )
    else:
      self.cond_proj = None

    self.act = get_activation(act_fn)

    if out_dim is not None:
      time_embed_dim_out = out_dim
    else:
      time_embed_dim_out = time_embed_dim

    self.linear_2 = nnx.Linear(
        rngs=rngs,
        in_features=time_embed_dim,
        out_features=time_embed_dim_out,
        use_bias=sample_proj_bias,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

    if post_act_fn is None:
      self.post_act = None
    else:
      self.post_act = get_activation(post_act_fn)

  def __call__(self, sample, condition=None):
    """Embed `sample`; if `condition` is given it is projected and added first.

    Note: passing `condition` requires the module to have been built with
    `cond_proj_dim`, otherwise `self.cond_proj` is None and this raises.
    """
    if condition is not None:
      sample = sample + self.cond_proj(condition)
    sample = self.linear_1(sample)

    if self.act is not None:
      sample = self.act(sample)
    sample = self.linear_2(sample)

    if self.post_act is not None:
      sample = self.post_act(sample)
    return sample
140+
59141

60142
class FlaxTimestepEmbedding(nn.Module):
61143
r"""
@@ -79,6 +161,23 @@ def __call__(self, temb):
79161
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb)
80162
return temb
81163

164+
class NNXFlaxTimesteps(nnx.Module):
  """Sinusoidal timestep projection (NNX counterpart of `FlaxTimesteps`).

  Stores a fixed configuration and delegates to `get_sinusoidal_embeddings`
  on call.

  NOTE(review): `scale` is stored but never forwarded to
  `get_sinusoidal_embeddings` — confirm whether it should be passed through.
  """

  def __init__(self, dim: int = 32, flip_sin_to_cos: bool = False, freq_shift: float = 1.0, scale: int = 1):
    # Embedding width and sinusoid layout options used at call time.
    self.dim = dim
    self.flip_sin_to_cos = flip_sin_to_cos
    self.freq_shift = freq_shift
    self.scale = scale

  def __call__(self, timesteps):
    """Return sinusoidal embeddings of shape (len(timesteps), self.dim)."""
    return get_sinusoidal_embeddings(
        timesteps,
        embedding_dim=self.dim,
        flip_sin_to_cos=self.flip_sin_to_cos,
        freq_shift=self.freq_shift,
    )
82181

83182
class FlaxTimesteps(nn.Module):
84183
r"""

src/maxdiffusion/models/modeling_flax_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@
4343

4444
logger = logging.get_logger(__name__)
4545

46+
# Registry of supported activation names -> jax.nn callables.
# "swish" and "silu" are aliases for the same function.
_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish}


def get_activation(name: str):
  """Return the `jax.nn` activation registered under `name`.

  Raises:
      ValueError: if `name` is not one of the supported activations.
  """
  if name not in _ACTIVATIONS:
    raise ValueError(f"Unknown activation function: {name}")
  return _ACTIVATIONS[name]
4653

4754
class FlaxModelMixin(PushToHubMixin):
4855
r"""

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,14 @@
2020
import jax.numpy as jnp
2121
from flax import nnx
2222
from ...configuration_utils import ConfigMixin
23-
from ..modeling_flax_utils import FlaxModelMixin
23+
from ..modeling_flax_utils import FlaxModelMixin, get_activation
2424
from ... import common_types
2525
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
2626

2727
BlockSizes = common_types.BlockSizes
2828

2929
CACHE_T = 2
3030

31-
_ACTIVATIONS = {"swish": jax.nn.silu, "silu": jax.nn.silu, "relu": jax.nn.relu, "gelu": jax.nn.gelu, "mish": jax.nn.mish}
32-
33-
34-
def get_activation(name: str):
35-
func = _ACTIVATIONS.get(name)
36-
if func is None:
37-
raise ValueError(f"Unknown activation function: {name}")
38-
return func
39-
4031

4132
# Helper to ensure kernel_size, stride, padding are tuples of 3 integers
4233
def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Tuple, Optional
17+
from typing import Tuple, Optional, Dict, Union, Any
1818
import jax
1919
import jax.numpy as jnp
2020
from flax import nnx
2121
from .... import common_types, max_logging
2222
from ...modeling_flax_utils import FlaxModelMixin
23-
from ....configuration_utils import ConfigMixin
24-
from ...embeddings_flax import get_1d_rotary_pos_embed
23+
from ....configuration_utils import ConfigMixin, register_to_config
24+
from ...embeddings_flax import get_1d_rotary_pos_embed, NNXFlaxTimesteps, NNXTimestepEmbedding
2525

2626
BlockSizes = common_types.BlockSizes
2727

@@ -65,7 +65,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
6565
cumulative_sizes = jnp.cumsum(jnp.array(sizes))
6666
split_indices = cumulative_sizes[:-1]
6767
freqs_split = jnp.split(self.freqs, split_indices, axis=1)
68-
68+
6969
freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
7070
freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
7171

@@ -80,6 +80,40 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
8080
return freqs_final
8181

8282

83+
class WanTimeTextImageEmbedding(nnx.Module):
  """Condition embedder for the Wan transformer.

  Currently only the timestep path is implemented: timesteps are projected to
  sinusoidal features (`time_freq_dim`) and then through an MLP to `dim`.

  NOTE(review): `time_proj_dim`, `text_embed_dim`, `image_embed_dim`, and
  `pos_embed_seq_len` are accepted but not yet used — the text/image embedders
  and the `time_proj` projection still need to be ported.

  Args:
      rngs (`nnx.Rngs`): RNG state for parameter initialization.
      dim (`int`): Output embedding dimension of the timestep MLP.
      time_freq_dim (`int`): Width of the sinusoidal timestep features.
      time_proj_dim (`int`): Intended width of the (not yet implemented) timestep projection.
      text_embed_dim (`int`): Dimension of incoming text embeddings (not yet used).
      image_embed_dim (`int`, *optional*): Dimension of incoming image embeddings (not yet used).
      pos_embed_seq_len (`int`, *optional*): Positional-embedding sequence length (not yet used).
      dtype / weights_dtype / precision: numeric configuration forwarded to submodules.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      dim: int,
      time_freq_dim: int,
      time_proj_dim: int,
      text_embed_dim: int,
      image_embed_dim: Optional[int] = None,
      pos_embed_seq_len: Optional[int] = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
  ):
    self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0)
    self.time_embedder = NNXTimestepEmbedding(
        rngs=rngs,
        in_channels=time_freq_dim,
        time_embed_dim=dim,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
    )

  def __call__(
      self,
      timestep: jax.Array,
      encoder_hidden_states: jax.Array,
      encoder_hidden_states_image: Optional[jax.Array] = None,
  ):
    """Return (temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image).

    Bug fix: the original left a `breakpoint()` debugger call in this method and
    returned None implicitly, while its caller unpacks four values.
    """
    timestep = self.timesteps_proj(timestep)
    temb = self.time_embedder(timestep)
    # TODO(review): `timestep_proj` should come from a time_proj layer of width
    # `time_proj_dim` once it is ported; None is a placeholder. The encoder
    # hidden states are passed through unchanged until the text/image embedders
    # are implemented.
    return temb, None, encoder_hidden_states, encoder_hidden_states_image
83117
class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin):
84118
def __init__(
85119
self,
@@ -120,25 +154,28 @@ def __init__(
120154

121155

122156
class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
123-
157+
158+
@register_to_config
124159
def __init__(
125160
self,
126161
rngs: nnx.Rngs,
127162
model_type="t2v",
128-
patch_size=(1, 2, 2),
129-
text_len=512,
130-
in_dim=16,
131-
dim=2048,
132-
ffn_dim=8192,
133-
freq_dim=256,
134-
text_dim=4096,
135-
out_dim=16,
136-
num_heads=16,
137-
num_layers=32,
138-
window_size=(-1, -1),
139-
qk_norm=True,
140-
cross_attn_norm=True,
141-
eps=1e-6,
163+
patch_size: Tuple[int] = (1, 2, 2),
164+
num_attention_heads: int = 40,
165+
attention_head_dim: int = 128,
166+
in_channels: int = 16,
167+
out_channels: int = 16,
168+
text_dim: int = 4096,
169+
freq_dim: int = 256,
170+
ffn_dim: int = 13824,
171+
num_layers: int = 40,
172+
cross_attn_norm: bool = True,
173+
qk_norm: Optional[str] = "rms_norm_across_heads",
174+
eps: float = 1e-6,
175+
image_dim: Optional[int] = None,
176+
added_kn_proj_dim: Optional[int] = None,
177+
rope_max_seq_len: int = 1024,
178+
pos_embed_seq_len: Optional[int] = None,
142179
flash_min_seq_length: int = 4096,
143180
flash_block_sizes: BlockSizes = None,
144181
mesh: jax.sharding.Mesh = None,
@@ -147,18 +184,62 @@ def __init__(
147184
precision: jax.lax.Precision = None,
148185
attention: str = "dot_product",
149186
):
150-
self.path_embedding = nnx.Conv(
151-
in_dim,
152-
dim,
187+
188+
inner_dim = num_attention_heads * attention_head_dim
189+
out_channels = out_channels or in_channels
190+
191+
#1. Patch & position embedding
192+
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
193+
self.patch_embedding = nnx.Conv(
194+
in_channels,
195+
inner_dim,
196+
rngs=rngs,
153197
kernel_size=patch_size,
154198
strides=patch_size,
155199
dtype=dtype,
156200
param_dtype=weights_dtype,
157201
precision=precision,
158202
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
159-
rngs=rngs,
160203
)
161204

162-
def __call__(self, x):
163-
x = self.path_embedding(x)
164-
return x
205+
# 2. Condition embeddings
206+
# image_embedding_dim=1280 for I2V model
207+
self.condition_embedder = WanTimeTextImageEmbedding(
208+
rngs=rngs,
209+
dim=inner_dim,
210+
time_freq_dim=freq_dim,
211+
time_proj_dim=inner_dim * 6,
212+
text_embed_dim=text_dim,
213+
image_embed_dim=image_dim,
214+
pos_embed_seq_len=pos_embed_seq_len
215+
)
216+
217+
def __call__(
218+
self,
219+
hidden_states: jax.Array,
220+
timestep: jax.Array,
221+
encoder_hidden_states: jax.Array,
222+
encoder_hidden_states_image: Optional[jax.Array] = None,
223+
return_dict: bool = True,
224+
attention_kwargs: Optional[Dict[str, Any]] = None,
225+
) -> Union[jax.Array, Dict[str, jax.Array]]:
226+
batch_size, num_frames, height, width, num_channels = hidden_states.shape
227+
p_t, p_h, p_w = self.config.patch_size
228+
post_patch_num_frames = num_frames // p_t
229+
post_patch_height = height // p_h
230+
post_patch_width = width // p_w
231+
232+
233+
rotary_emb = self.rope(hidden_states)
234+
hidden_states = self.patch_embedding(hidden_states)
235+
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
236+
237+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
238+
timestep, encoder_hidden_states, encoder_hidden_states_image
239+
)
240+
#hidden_states =
241+
# Torch shape: ([1, 5120, 21, 45, 80])
242+
# Jax shape: (1, 21, 45, 80, 5120) so channels is 5120
243+
244+
245+
return hidden_states

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from flax import nnx
2222

2323
from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed
24+
from ..models.embeddings_flax import NNXTimestepEmbedding
2425

2526
class WanTransformerTest(unittest.TestCase):
2627
def setUp(self):
@@ -41,7 +42,19 @@ def test_rotary_pos_embed(self):
4142
)
4243
dummy_output = wan_rot_embed(dummy_hidden_states)
4344
assert dummy_output.shape == (1, 1, 75600, 64)
44-
# output shape should be torch.Size([1, 1, 75600, 64])
45+
46+
def test_nnx_timestep_embedding(self):
  """NNXTimestepEmbedding maps (1, in_channels) -> (1, time_embed_dim)."""
  rngs = nnx.Rngs(jax.random.key(0))
  sample = jnp.ones((1, 256))

  embedding = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
  output = embedding(sample)
  assert output.shape == (1, 5120)
4558

4659
if __name__ == "__main__":
4760
absltest.main()

0 commit comments

Comments
 (0)