
Commit 4c00085

add transformer block
1 parent 1abc00c commit 4c00085

3 files changed

Lines changed: 274 additions & 14 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 7 deletions
@@ -697,11 +697,11 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
   def __call__(
       self,
       hidden_states: jax.Array,
-      encoder_hidden_states: jax.Array,
+      encoder_hidden_states: Optional[jax.Array] = None,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
+
     dtype = hidden_states.dtype
-    # batch_size = hidden_states.shape[0]
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)

@@ -715,12 +715,14 @@ def __call__(
     if self.qk_norm:
       query_proj = self.query_norm(query_proj)
       key_proj = self.key_norm(key_proj)
-    query_proj = _unflatten_heads(query_proj, self.heads)
-    key_proj = _unflatten_heads(key_proj, self.heads)
+
     if rotary_emb is not None:
+      query_proj = _unflatten_heads(query_proj, self.heads)
+      key_proj = _unflatten_heads(key_proj, self.heads)
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-    query_proj = _reshape_heads_to_head_dim(query_proj)
-    key_proj = _reshape_heads_to_head_dim(key_proj)
+      query_proj = _reshape_heads_to_head_dim(query_proj)
+      key_proj = _reshape_heads_to_head_dim(key_proj)
+
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
     attn_output = attn_output.astype(dtype=dtype)

@@ -1309,7 +1311,6 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k
     hidden_states = hidden_states + residual
     return self.dropout_layer(hidden_states, deterministic=deterministic)

-
 class FlaxFeedForward(nn.Module):
   r"""
   Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
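Note on the second hunk: RoPE operates on a per-head view of the projections, so the commit moves `_unflatten_heads` and `_reshape_heads_to_head_dim` inside the `rotary_emb is not None` branch, keeping the flattened (B, S, H*D) layout everywhere else. A minimal, self-contained sketch of that control flow, with stand-in helpers (the repo's `_unflatten_heads`, `_apply_rope`, and `_reshape_heads_to_head_dim` are assumed to behave roughly like these):

import jax
import jax.numpy as jnp

# Stand-ins for the repo's private helpers; (B, S, H*D) input shapes assumed.
def unflatten_heads(x, heads):
  b, s, _ = x.shape
  return x.reshape(b, s, heads, -1).transpose(0, 2, 1, 3)  # (B, H, S, D)

def flatten_heads(x):
  b, h, s, d = x.shape
  return x.transpose(0, 2, 1, 3).reshape(b, s, h * d)  # (B, S, H*D)

def apply_rope(xq, xk, freqs_cis):
  # Rotate channel pairs as complex numbers by the precomputed frequencies.
  to_c = lambda x: jax.lax.complex(x[..., ::2], x[..., 1::2])
  from_c = lambda c: jnp.stack([c.real, c.imag], axis=-1).reshape(*c.shape[:-1], -1)
  return from_c(to_c(xq) * freqs_cis), from_c(to_c(xk) * freqs_cis)

b, s, heads, head_dim = 1, 8, 4, 16
query_proj = jnp.ones((b, s, heads * head_dim))
key_proj = jnp.ones((b, s, heads * head_dim))
rotary_emb = jnp.exp(1j * jnp.zeros((1, 1, s, head_dim // 2)))  # identity rotation

if rotary_emb is not None:
  # The per-head view exists only while RoPE is applied.
  q, k = unflatten_heads(query_proj, heads), unflatten_heads(key_proj, heads)
  q, k = apply_rope(q, k, rotary_emb)
  query_proj, key_proj = flatten_heads(q), flatten_heads(k)

assert query_proj.shape == (b, s, heads * head_dim)

With the identity rotation the projections round-trip unchanged; the point is only that the layout change now happens exclusively on the RoPE path.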

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 196 additions & 6 deletions
@@ -28,6 +28,7 @@
     NNXPixArtAlphaTextProjection
 )
 from ...normalization_flax import FP32LayerNorm
+from ...attention_flax import FlaxWanAttention

 BlockSizes = common_types.BlockSizes

@@ -181,6 +182,89 @@ def __init__(
       rope_max_seq_len
     )

+class ApproximateGELU(nnx.Module):
+  r"""
+  The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+  [paper](https://arxiv.org/abs/1606.08415).
+  """
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      dim_in: int,
+      dim_out: int,
+      bias: bool,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+  ):
+    self.proj = nnx.Linear(
+        rngs=rngs,
+        in_features=dim_in,
+        out_features=dim_out,
+        use_bias=bias,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    x = self.proj(x)
+    return x * jax.nn.sigmoid(1.702 * x)
+
+
+class WanFeedForward(nnx.Module):
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      dim: int,
+      dim_out: Optional[int] = None,
+      mult: int = 4,
+      dropout: float = 0.0,
+      activation_fn: str = "geglu",
+      final_dropout: bool = False,
+      inner_dim: Optional[int] = None,
+      bias: bool = True,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+  ):
+    if inner_dim is None:
+      inner_dim = int(dim * mult)
+    dim_out = dim_out if dim_out is not None else dim
+
+    self.act_fn = None
+    if activation_fn == "gelu-approximate":
+      self.act_fn = ApproximateGELU(
+          rngs=rngs,
+          dim_in=dim,
+          dim_out=inner_dim,
+          bias=bias,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          precision=precision
+      )
+    else:
+      raise NotImplementedError(f"{activation_fn} is not implemented.")
+
+    self.proj_out = nnx.Linear(
+        rngs=rngs,
+        in_features=inner_dim,
+        out_features=dim_out,
+        use_bias=bias,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
+
+  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    hidden_states = self.act_fn(hidden_states)
+    return self.proj_out(hidden_states)
+

 class WanTransformerBlock(nnx.Module):
   def __init__(
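Aside on `ApproximateGELU`: the `x * sigmoid(1.702 * x)` form above is the sigmoid approximation of GELU from the linked paper, not the tanh form that `jax.nn.gelu` uses by default. A quick standalone comparison (values illustrative only):

import jax
import jax.numpy as jnp

# x * sigmoid(1.702 * x) vs. jax.nn.gelu's tanh approximation.
x = jnp.linspace(-4.0, 4.0, 9)
sigmoid_gelu = x * jax.nn.sigmoid(1.702 * x)
tanh_gelu = jax.nn.gelu(x, approximate=True)
print(jnp.max(jnp.abs(sigmoid_gelu - tanh_gelu)))  # small but nonzero, roughly 1e-2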
@@ -192,17 +276,107 @@ def __init__(
       qk_norm: str = "rms_norm_across_heads",
       cross_attn_norm: bool = False,
       eps: float = 1e-6,
-      added_kv_proj_dim: Optional[int] = None
+      # In torch, this is None, so it can be ignored here.
+      # added_kv_proj_dim: Optional[int] = None,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
   ):
+    # 1. Self-attention
     self.norm1 = FP32LayerNorm(
+        rngs=rngs,
         dim=dim,
         eps=eps,
         elementwise_affine=False
     )
+    self.attn1 = FlaxWanAttention(
+        rngs=rngs,
+        query_dim=dim,
+        heads=num_heads,
+        dim_head=dim // num_heads,
+        qk_norm=qk_norm,
+        eps=eps,
+        flash_min_seq_length=flash_min_seq_length,
+        flash_block_sizes=flash_block_sizes,
+        mesh=mesh,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        precision=precision,
+        attention_kernel=attention
+    )
+
+    # 2. Cross-attention
+    self.attn2 = FlaxWanAttention(
+        rngs=rngs,
+        query_dim=dim,
+        heads=num_heads,
+        dim_head=dim // num_heads,
+        qk_norm=qk_norm,
+        eps=eps,
+        flash_min_seq_length=flash_min_seq_length,
+        flash_block_sizes=flash_block_sizes,
+        mesh=mesh,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        precision=precision,
+        attention_kernel=attention
+    )
+    assert cross_attn_norm, "WanTransformerBlock currently requires cross_attn_norm=True."
+    self.norm2 = FP32LayerNorm(
+        rngs=rngs,
+        dim=dim,
+        eps=eps,
+        elementwise_affine=True
+    )
+
+    # 3. Feed-forward
+    self.ffn = WanFeedForward(
+        rngs=rngs,
+        dim=dim,
+        inner_dim=ffn_dim,
+        activation_fn="gelu-approximate",
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        precision=precision
+    )
+    self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
+
+    key = rngs.params()
+    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)

-  def __call__(self):
-    pass
+  def __call__(
+      self,
+      hidden_states: jax.Array,
+      encoder_hidden_states: jax.Array,
+      temb: jax.Array,
+      rotary_emb: jax.Array
+  ):
+    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+        self.scale_shift_table + temb.astype(jnp.float32), 6, axis=1
+    )
+
+    # 1. Self-attention
+    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+    attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+    # 2. Cross-attention
+    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32))
+    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+    hidden_states = hidden_states + attn_output
+
+    # 3. Feed-forward
+    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
+    ff_output = self.ffn(norm_hidden_states)
+    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype)
+    return hidden_states


 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
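The `scale_shift_table` arithmetic in `__call__` above is adaLN-style modulation: adding the (B, 6, dim) time embedding to the learned (1, 6, dim) table and splitting on axis 1 yields six (B, 1, dim) vectors that shift, scale, and gate the attention and feed-forward branches, broadcasting over the sequence axis. A small sketch of just that arithmetic (shapes follow the test below; this is not the repo's code):

import jax
import jax.numpy as jnp

batch, seq, dim = 1, 10, 32
table = jax.random.normal(jax.random.key(0), (1, 6, dim)) / dim**0.5
temb = jnp.ones((batch, 6, dim))  # time embedding, as in the block's __call__

# (1, 6, dim) + (B, 6, dim) broadcasts, then splits into six (B, 1, dim) chunks.
shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = jnp.split(table + temb, 6, axis=1)

x = jnp.ones((batch, seq, dim))              # token activations
modulated = x * (1 + scale_msa) + shift_msa  # broadcasts over the sequence axis
x = x + modulated * gate_msa                 # gated residual, as in the self-attention step
print(x.shape)                               # (1, 10, 32)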
@@ -269,7 +443,22 @@ def __init__(
     # 3. Transformer blocks
     blocks = []
     for _ in range(num_layers):
-      block = WanTransformerBlock()
+      block = WanTransformerBlock(
+          rngs=rngs,
+          dim=inner_dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_attention_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          flash_min_seq_length=flash_min_seq_length,
+          flash_block_sizes=flash_block_sizes,
+          mesh=mesh,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          precision=precision,
+          attention=attention
+      )
       blocks.append(block)
     self.blocks = blocks

@@ -301,8 +490,9 @@
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")

-    # for block in self.blocks:
-
+    for block in self.blocks:
+      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

     return hidden_states

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 70 additions & 1 deletion
@@ -27,7 +27,7 @@
     create_device_mesh,
     get_flash_block_sizes
 )
-from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding
+from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
 from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
 from ..models.normalization_flax import FP32LayerNorm
 from ..models.attention_flax import FlaxWanAttention
@@ -119,6 +119,75 @@ def test_wan_time_text_embedding(self):
     assert timestep_proj.shape == (batch_size, time_proj_dim)
     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

+  def test_wan_block(self):
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True
+    )
+    config = pyconfig.config
+
+    devices_array = create_device_mesh(config)
+    flash_block_sizes = get_flash_block_sizes(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+
+    dim = 5120
+    ffn_dim = 13824
+    num_heads = 40
+    qk_norm = "rms_norm_across_heads"
+    cross_attn_norm = True
+    eps = 1e-6
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_dim = 75600
+
+    # For the rotary positional embedding.
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+
+    wan_rot_embed = WanRotaryPosEmbed(
+        attention_head_dim=128,
+        patch_size=[1, 2, 2],
+        max_seq_len=1024
+    )
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+    assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64)
+
+    # For the transformer block.
+    dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim))
+    dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
+    dummy_temb = jnp.ones((batch_size, 6, dim))
+
+    wan_block = WanTransformerBlock(
+        rngs=rngs,
+        dim=dim,
+        ffn_dim=ffn_dim,
+        num_heads=num_heads,
+        qk_norm=qk_norm,
+        cross_attn_norm=cross_attn_norm,
+        eps=eps,
+        attention="flash",
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes
+    )
+
+    dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
+    assert dummy_output.shape == dummy_hidden_states.shape
+
 def test_wan_attention(self):
     pyconfig.initialize(
         [
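One number in `test_wan_block` worth unpacking: `hidden_dim = 75600` is consistent with patchifying the (frames, height, width) = (21, 90, 160) grid using `patch_size = [1, 2, 2]`:

frames, height, width = 21, 90, 160
p_t, p_h, p_w = 1, 2, 2  # patch_size
hidden_dim = (frames // p_t) * (height // p_h) * (width // p_w)
print(hidden_dim)  # 21 * 45 * 80 = 75600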
