Commit 440f39c

wan transformer with in/out shapes verified
1 parent 4c00085 commit 440f39c

2 files changed: 71 additions & 5 deletions

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -15,10 +15,11 @@
 """
 
 from typing import Tuple, Optional, Dict, Union, Any
+import math
 import jax
 import jax.numpy as jnp
 from flax import nnx
-from .... import common_types, max_logging
+from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin, get_activation
 from ....configuration_utils import ConfigMixin, register_to_config
 from ...embeddings_flax import (
@@ -447,7 +448,7 @@ def __init__(
           rngs=rngs,
           dim=inner_dim,
           ffn_dim=ffn_dim,
-          num_attention_heads=num_attention_heads,
+          num_heads=num_attention_heads,
           qk_norm=qk_norm,
           cross_attn_norm=cross_attn_norm,
           eps=eps,
@@ -462,6 +463,20 @@ def __init__(
       blocks.append(block)
     self.blocks = blocks
 
+    self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
+    self.proj_out = nnx.Linear(
+        rngs=rngs,
+        in_features=inner_dim,
+        out_features=out_channels * math.prod(patch_size),
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
+    key = rngs.params()
+    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5)
+
   def __call__(
       self,
       hidden_states: jax.Array,
```
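The `scale_shift_table` added above follows the AdaLN-style output modulation used in diffusers' Wan transformer: a learned `(1, 2, inner_dim)` table is summed with the broadcast conditioning embedding and split into a per-channel shift and scale, which the next hunk applies after the final norm. A minimal shape sketch of the mechanism, with toy sizes rather than the model's real width (standalone code, not part of the commit):

```python
import jax
import jax.numpy as jnp

batch, inner_dim = 2, 8  # toy sizes, not the model's real width
key = jax.random.key(0)

# learned table of shape (1, 2, inner_dim): one row becomes shift, the other scale
scale_shift_table = jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5
temb = jnp.ones((batch, inner_dim))  # per-sample conditioning embedding

# (1, 2, D) + (B, 1, D) broadcasts to (B, 2, D); splitting axis 1 gives two (B, 1, D)
shift, scale = jnp.split(scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
assert shift.shape == scale.shape == (batch, 1, inner_dim)

# applied to a (B, seq, D) activation after the final norm; broadcasts over seq
x = jnp.ones((batch, 5, inner_dim))
y = x * (1 + scale) + shift
assert y.shape == x.shape
```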
```diff
@@ -492,7 +507,14 @@ def __call__(
 
     for block in self.blocks:
      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-    breakpoint()
+
+    shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
+
+    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
+    hidden_states = self.proj_out(hidden_states)
 
+    # TODO - can this reshape happen in a single command?
+    hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1)
+    hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels)
 
     return hidden_states
```
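On the `TODO` in this hunk: the two back-to-back reshapes produce the right output shape (consistent with the commit's "in/out shapes verified"), but reshapes alone cannot interleave the patch axes `(p_t, p_h, p_w)` with the spatial axes; the reference PyTorch implementation in diffusers places a permute between them. A hedged sketch of the reshape-transpose-reshape form, using toy dimensions standing in for the diff's variables and assuming the activations arrive as `(batch, seq, p_t*p_h*p_w*channels)`:

```python
import jax.numpy as jnp

# toy dims standing in for the diff's variables
batch_size, num_channels = 1, 16
p_t, p_h, p_w = 1, 2, 2  # illustrative patch factors
post_patch_num_frames, post_patch_height, post_patch_width = 3, 4, 5
num_frames = post_patch_num_frames * p_t
height = post_patch_height * p_h
width = post_patch_width * p_w

seq_len = post_patch_num_frames * post_patch_height * post_patch_width
hidden_states = jnp.zeros((batch_size, seq_len, p_t * p_h * p_w * num_channels))

# split the patch factors out of the channel axis
hidden_states = hidden_states.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
# interleave each patch axis with its spatial axis: (b, f', p_t, h', p_h, w', p_w, c)
hidden_states = hidden_states.transpose(0, 1, 4, 2, 5, 3, 6, 7)
hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels)
assert hidden_states.shape == (batch_size, num_frames, height, width, num_channels)
```

As for a single command: `einops.rearrange(x, "b (f h w) (pt ph pw c) -> b (f pt) (h ph) (w pw) c", ...)` with the six factors passed as keywords expresses the whole unpatchify in one call, assuming einops is available in the environment.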

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 46 additions & 2 deletions
```diff
@@ -27,7 +27,9 @@
     create_device_mesh,
     get_flash_block_sizes
 )
-from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+from ..models.wan.transformers.transformer_wan import (
+    WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock, WanModel
+)
 from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
 from ..models.normalization_flax import FP32LayerNorm
 from ..models.attention_flax import FlaxWanAttention
@@ -256,7 +258,49 @@ def test_wan_attention(self):
     except NotImplementedError as e:
       pass
 
-
+  def test_wan_model(self):
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    devices_array = create_device_mesh(config)
+
+    flash_block_sizes = get_flash_block_sizes(config)
+
+    mesh = Mesh(devices_array, config.mesh_axes)
+    batch_size = 1
+    query_dim = 5120
+    wan_model = WanModel(
+        rngs=rngs,
+        attention="flash",
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+    )
+
+    dummy_timestep = jnp.ones((batch_size))
+    dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
+
+    dummy_output = wan_model(
+        hidden_states=dummy_hidden_states,
+        timestep=dummy_timestep,
+        encoder_hidden_states=dummy_encoder_hidden_states
+    )
+    assert dummy_output.shape == hidden_states_shape
 
 if __name__ == "__main__":
   absltest.main()
```
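For reference, the shape arithmetic the new test exercises. The patch size is an assumption here (Wan 2.1 uses `(1, 2, 2)`); the real value comes from the `WanModel` defaults and `base_wan_14b.yml`, neither of which appears in this diff:

```python
# hypothetical shape check mirroring test_wan_model's dimensions
batch_size, channels, frames, height, width = 1, 16, 21, 90, 160
p_t, p_h, p_w = 1, 2, 2  # assumed patch_size, not shown in this diff

post_patch_num_frames = frames // p_t  # 21
post_patch_height = height // p_h      # 45
post_patch_width = width // p_w        # 80
seq_len = post_patch_num_frames * post_patch_height * post_patch_width  # 75600 tokens

# proj_out maps inner_dim to channels * p_t * p_h * p_w = 64 features per token,
# so (1, 75600, 64) and (1, 21, 90, 160, 16) hold the same number of elements;
# the test's final assert checks this round trip at the shape level
assert seq_len * channels * p_t * p_h * p_w == frames * height * width * channels
```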
