Commit f78d7e1

Add Wan-Animate inference pipeline, image processor, and transformer sharding
- wan_pipeline_animate.py: Full JAX/Flax port of the diffusers WanAnimatePipeline supporting both "animate" (pose+face) and "replace" (background+mask) modes, segmented inference with overlap conditioning, and optional CFG.
- image_processor.py: WanAnimateImageProcessor with letterbox resize (vs center-crop) and vae_scale_factor*spatial_patch_size-aligned dimensions for the reference character image.
- transformer_wan_animate.py: Added nnx.with_partitioning sharding annotations to patch_embedding, pose_patch_embedding, proj_out, scale_shift_table, and all FlaxWanAnimateFaceBlockCrossAttention projections; added nn.with_logical_constraint on the input hidden_states.
- wan_utils.py: Added load_wan_animate_transformer with motion-encoder-aware weight loading (skip weight→kernel rename and transpose for FlaxMotionConv2d/FlaxMotionLinear; map activation.bias→act_fn.bias for FusedLeakyReLU).
1 parent 94f9c21 commit f78d7e1

6 files changed

Lines changed: 1311 additions & 12 deletions
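
The new wan_pipeline_animate.py and image_processor.py sources are not included in the hunks below. As a rough sketch of the letterbox resize with vae_scale_factor * spatial_patch_size alignment described in the commit message (the function name, Pillow resampling choice, and default values here are illustrative assumptions, not the actual WanAnimateImageProcessor API):

```python
from PIL import Image, ImageOps


def letterbox_to_aligned(image: Image.Image, target_h: int, target_w: int,
                         vae_scale_factor: int = 8, spatial_patch_size: int = 2) -> Image.Image:
  """Resize with preserved aspect ratio, then pad (letterbox) to multiple-aligned dims."""
  multiple = vae_scale_factor * spatial_patch_size
  # Snap the requested size to the nearest multiple the VAE and patchifier can consume.
  target_h = (target_h // multiple) * multiple
  target_w = (target_w // multiple) * multiple
  # Scale so the whole reference image fits inside the target box (no cropping).
  scale = min(target_w / image.width, target_h / image.height)
  new_w, new_h = round(image.width * scale), round(image.height * scale)
  resized = image.resize((new_w, new_h), Image.LANCZOS)
  # Pad the remainder symmetrically: the letterbox alternative to center-crop.
  pad_w, pad_h = target_w - new_w, target_h - new_h
  return ImageOps.expand(resized, border=(pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2))
```
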

src/maxdiffusion/models/wan/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+from .transformer_wan_animate import NNXWanAnimateTransformer3DModel

src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py

Lines changed: 91 additions & 12 deletions
@@ -19,6 +19,7 @@
 import math
 import jax
 import jax.numpy as jnp
+import flax.linen as nn
 from flax import nnx
 from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin
@@ -81,6 +82,13 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxMotionConv2d(nnx.Module):
+  """2-D convolution with EqualizedLR scaling and optional FusedLeakyReLU.
+
+  Weights are stored in PyTorch OIHW format (out, in, k, k) as raw nnx.Param
+  so that the weight-loading code in wan_utils.py can map them without
+  transposing. No sharding annotations are applied because this module is
+  part of the small motion encoder network.
+  """
 
   def __init__(
       self,
@@ -123,7 +131,7 @@ def __init__(
     self.blur_kernel = None
 
     key = rngs.params()
-    # Shape: (out_channels, in_channels, kernel, kernel) mapping PyTorch 'OIHW'
+    # Shape: (out_channels, in_channels, kernel, kernel) PyTorch OIHW format.
     self.weight = nnx.Param(jax.random.normal(key, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype))
     self.scale = 1.0 / math.sqrt(in_channels * kernel_size**2)
 
@@ -156,7 +164,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
         x,
         expanded_kernel,
         window_strides=(1, 1),
-        padding=[(pad_h, pad_h), (pad_w, pad_w)],  # Corrected Symmetric Padding
+        padding=[(pad_h, pad_h), (pad_w, pad_w)],
         dimension_numbers=("NCHW", "OIHW", "NCHW"),
         feature_group_count=self.in_channels,
     )
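
The two hunks above show the EqualizedLR idea: the raw OIHW weight is kept unit-scaled at init and multiplied by 1/sqrt(fan_in) at call time, then applied through jax.lax.conv_general_dilated with PyTorch-style dimension numbers. A minimal standalone sketch of that path (omitting the blur kernel, depthwise grouping, and FusedLeakyReLU that FlaxMotionConv2d also handles):

```python
import math

import jax
import jax.numpy as jnp


def equalized_conv2d(x: jax.Array, weight: jax.Array, pad: int) -> jax.Array:
  """x: (N, C_in, H, W); weight: raw OIHW parameter of shape (C_out, C_in, k, k)."""
  fan_in = weight.shape[1] * weight.shape[2] * weight.shape[3]
  scale = 1.0 / math.sqrt(fan_in)  # applied at call time, not baked into the stored weight
  return jax.lax.conv_general_dilated(
      x,
      weight * scale,
      window_strides=(1, 1),
      padding=[(pad, pad), (pad, pad)],
      dimension_numbers=("NCHW", "OIHW", "NCHW"),
  )


x = jnp.ones((1, 3, 8, 8))
w = jax.random.normal(jax.random.PRNGKey(0), (16, 3, 3, 3))
print(equalized_conv2d(x, w, pad=1).shape)  # (1, 16, 8, 8)
```
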
@@ -186,6 +194,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxMotionLinear(nnx.Module):
+  """Equalized-LR linear layer with optional FusedLeakyReLU.
+
+  Weights are stored in PyTorch (out, in) format as raw nnx.Param — same
+  reason as FlaxMotionConv2d. No sharding annotations needed (small layer).
+  """
 
   def __init__(
       self,
@@ -296,6 +309,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxWanAnimateMotionEncoder(nnx.Module):
+  """Encodes a face video frame into a motion vector.
+
+  All weights in this network are small (the largest is 32×512→16) so
+  sharding annotations are not applied.
+  """
 
   def __init__(
       self,
@@ -395,7 +413,6 @@ def __init__(
 
     self.act = jax.nn.silu
 
-    # Added explicit padding="VALID" to exactly mirror PyTorch's padding=0 default
     self.conv1_local = nnx.Conv(
         in_dim,
         hidden_dim * num_heads,
@@ -449,7 +466,15 @@ def __init__(
         dtype=dtype,
     )
 
-    self.out_proj = nnx.Linear(hidden_dim, out_dim, rngs=rngs, dtype=dtype)
+    # hidden_dim (mlp) → out_dim (embed): ("mlp", "embed")
+    self.out_proj = nnx.Linear(
+        hidden_dim,
+        out_dim,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
 
     self.padding_tokens = nnx.Param(jnp.zeros((1, 1, 1, out_dim), dtype=dtype))
 
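
nnx.with_partitioning wraps an initializer so that the created Param carries logical sharding metadata ("mlp", "embed", "heads", ...) alongside its value; nnx.get_partition_spec later turns that metadata into jax.sharding.PartitionSpec entries. A minimal sketch with illustrative sizes (not the module above):

```python
from flax import nnx

layer = nnx.Linear(
    512,
    1024,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    rngs=nnx.Rngs(0),
)

# The kernel Param now carries sharding metadata; get_partition_spec exposes it
# as PartitionSpec("mlp", "embed") for the kernel and PartitionSpec("embed") for the bias.
print(nnx.get_partition_spec(nnx.state(layer)))
```
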
@@ -510,11 +535,45 @@ def __init__(
     self.pre_norm_q = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype)
     self.pre_norm_kv = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype)
 
-    self.to_q = nnx.Linear(dim, self.inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
-    self.to_k = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
-    self.to_v = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
+    # embed → heads
+    self.to_q = nnx.Linear(
+        dim,
+        self.inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
+    self.to_k = nnx.Linear(
+        dim,
+        self.kv_inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
+    self.to_v = nnx.Linear(
+        dim,
+        self.kv_inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
 
-    self.to_out = nnx.Linear(self.inner_dim, dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
+    # heads → embed
+    self.to_out = nnx.Linear(
+        self.inner_dim,
+        dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
 
     self.norm_q = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype)
     self.norm_k = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype)
@@ -544,14 +603,14 @@ def __call__(
 
     query_S = query.shape[1]
 
-    # Prepare for attention by folding Time into the Batch dimension
+    # Fold Time into the Batch dimension for attention
     query = jnp.reshape(query, (B * T, query_S // T, self.heads, -1))
     key = jnp.reshape(key, (B * T, N, self.heads, -1))
     value = jnp.reshape(value, (B * T, N, self.heads, -1))
 
     attn_output = jax.nn.dot_product_attention(query, key, value)
 
-    # Collapse Time, Seq Length, and Heads straight back to (Batch, Total Sequence, Dim)
+    # Restore (Batch, Total Sequence, Dim)
     attn_output = jnp.reshape(attn_output, (B, query_S, -1))
 
     hidden_states = self.to_out(attn_output)
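
A worked shape example of the fold/unfold in this hunk, with small illustrative sizes (B batch, T face frames, S latent tokens per frame, N motion tokens per frame):

```python
import jax
import jax.numpy as jnp

B, T, S, N, heads, head_dim = 2, 4, 6, 5, 8, 16
dim = heads * head_dim

query = jnp.zeros((B, T * S, dim))        # per-frame queries concatenated along the sequence axis
key = value = jnp.zeros((B, T * N, dim))  # per-frame motion features

q = jnp.reshape(query, (B * T, S, heads, head_dim))  # fold time into batch
k = jnp.reshape(key, (B * T, N, heads, head_dim))
v = jnp.reshape(value, (B * T, N, heads, head_dim))

out = jax.nn.dot_product_attention(q, k, v)          # (B*T, S, heads, head_dim)
out = jnp.reshape(out, (B, T * S, dim))              # back to (batch, total sequence, dim)
print(out.shape)  # (2, 24, 128)
```
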
@@ -624,6 +683,8 @@ def __init__(
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
 
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+
+    # Patch embeddings — shard output (conv_out) axis across model parallelism.
     self.patch_embedding = nnx.Conv(
         in_channels,
         inner_dim,
@@ -632,6 +693,10 @@
         rngs=rngs,
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
     )
     self.pose_patch_embedding = nnx.Conv(
         latent_channels,
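
The five-entry partition spec matches the kernel layout of a 3-D nnx.Conv, which stores its kernel as (k_t, k_h, k_w, in_features, out_features); only the output-feature axis carries the "conv_out" logical name. A quick check with illustrative sizes:

```python
from flax import nnx

conv = nnx.Conv(
    16,
    32,
    kernel_size=(1, 2, 2),
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        (None, None, None, None, "conv_out"),
    ),
    rngs=nnx.Rngs(0),
)
print(conv.kernel.value.shape)  # (1, 2, 2, 16, 32); "conv_out" annotates the last axis
```
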
@@ -641,6 +706,10 @@
         rngs=rngs,
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
     )
 
     self.condition_embedder = WanTimeTextImageEmbedding(
@@ -714,15 +783,22 @@ def __init__(
     self.face_adapter = nnx.List(face_adapters)
 
     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
+
+    # Final projection — embed → output tokens.
     self.proj_out = nnx.Linear(
         rngs=rngs,
         in_features=inner_dim,
         out_features=out_channels * math.prod(patch_size),
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
     )
+
     key = rngs.params()
-    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim), dtype=dtype) / inner_dim**0.5)
+    self.scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 2, inner_dim), dtype=dtype) / inner_dim**0.5,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
+    )
 
   def conditional_named_scope(self, name: str):
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
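
For a raw nnx.Param there is no initializer to wrap; the usual way to attach the same kind of logical sharding metadata is the sharding keyword on the Param itself, which nnx.get_partition_spec also reads. A minimal sketch (the inner_dim of 1536 is illustrative):

```python
import jax.numpy as jnp
from flax import nnx


class Table(nnx.Module):

  def __init__(self):
    # Sharding metadata goes on the Param directly when there is no initializer to wrap.
    self.scale_shift_table = nnx.Param(jnp.zeros((1, 2, 1536)), sharding=(None, None, "embed"))


print(nnx.get_partition_spec(nnx.state(Table())))
```
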
@@ -747,6 +823,9 @@ def __call__(
           f"Pose frames + 1 ({pose_hidden_states.shape[2]} + 1) must equal hidden_states frames ({hidden_states.shape[2]})"
       )
 
+    # Constrain input to batch-sharded layout before any computation.
+    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
+
     batch_size, num_channels, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.patch_size
     post_patch_num_frames = num_frames // p_t
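
nn.with_logical_constraint resolves the logical "batch" name through the active logical-axis rules and then calls jax.lax.with_sharding_constraint. Written directly against jax.sharding primitives, the constraint above is roughly equivalent to the following sketch (single mesh axis and the axis mapping are illustrative, and the batch size must divide across the available devices):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))


@jax.jit
def constrain(x):
  # "batch" maps to the "data" mesh axis; the remaining four dims stay unsharded.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data", None, None, None, None)))


x = jnp.zeros((8, 16, 5, 32, 32))  # (batch, channels, frames, height, width)
print(constrain(x).sharding)
```
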
@@ -850,7 +929,7 @@ def encode_chunk_fn(carry, chunk):
            rngs,
        )
 
-        # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...)
+        # Face adapter integration: apply after every inject_face_latents_blocks-th block
         if motion_vec is not None and block_idx % self.inject_face_latents_blocks == 0:
           face_adapter_block_idx = block_idx // self.inject_face_latents_blocks
           face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
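
In words: with inject_face_latents_blocks = 5 the adapter fires on blocks 0, 5, 10, ..., and block_idx // inject_face_latents_blocks picks the matching entry of face_adapter. A tiny check of that index mapping (40 blocks, as in the default num_layers):

```python
inject_face_latents_blocks = 5
pairs = [(b, b // inject_face_latents_blocks) for b in range(40) if b % inject_face_latents_blocks == 0]
print(pairs[:4])  # [(0, 0), (5, 1), (10, 2), (15, 3)]
```
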

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 138 additions & 0 deletions
@@ -300,6 +300,144 @@ def load_base_wan_transformer(
   return flax_state_dict
 
 
+def _is_motion_encoder_custom_weight(pt_key: str) -> bool:
+  """Returns True for FlaxMotionConv2d/FlaxMotionLinear weight keys that must NOT be renamed to kernel."""
+  prefixes = (
+      "motion_encoder.conv_in.",
+      "motion_encoder.conv_out.",
+  )
+  if any(pt_key.startswith(p) for p in prefixes) and pt_key.endswith(".weight"):
+    return True
+  if "motion_encoder.res_blocks." in pt_key and pt_key.endswith(".weight"):
+    return True
+  if "motion_encoder.motion_network." in pt_key and pt_key.endswith(".weight"):
+    return True
+  return False
+
+
+def load_wan_animate_transformer(
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
+    subfolder: str = "transformer",
+):
+  """Loads WanAnimate transformer weights from a HuggingFace checkpoint.
+
+  Handles the additional key mappings for:
+  - pose_patch_embedding (nnx.Conv3d → kernel)
+  - motion_encoder.* (FlaxMotionConv2d/FlaxMotionLinear → keep as 'weight', no transpose)
+  - activation.bias → act_fn.bias (FusedLeakyReLU bias remapping)
+  - face_encoder.* (nnx.Conv/Linear → standard rename to kernel)
+  - face_adapter.* (nnx.Linear → standard rename to kernel)
+  """
+  device = jax.local_devices(backend=device)[0]
+  filename = "diffusion_pytorch_model.safetensors.index.json"
+  local_files = False
+  if os.path.isdir(pretrained_model_name_or_path):
+    index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
+    if not os.path.isfile(index_file_path):
+      raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
+    local_files = True
+  elif hf_download:
+    index_file_path = hf_hub_download(
+        pretrained_model_name_or_path,
+        subfolder=subfolder,
+        filename=filename,
+    )
+  with jax.default_device(device):
+    with open(index_file_path, "r") as f:
+      index_dict = json.load(f)
+    model_files = set()
+    for key in index_dict["weight_map"].keys():
+      model_files.add(index_dict["weight_map"][key])
+
+    model_files = list(model_files)
+    tensors = {}
+    for model_file in model_files:
+      if local_files:
+        ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
+      else:
+        ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
+      max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
+      if ckpt_shard_path is not None:
+        with safe_open(ckpt_shard_path, framework="pt") as f:
+          for k in f.keys():
+            tensors[k] = torch2jax(f.get_tensor(k))
+
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      string_tuple = tuple([str(item) for item in key])
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    del flattened_dict
+
+    for pt_key, tensor in tensors.items():
+      if "norm_added_q" in pt_key:
+        continue
+
+      renamed_pt_key = rename_key(pt_key)
+
+      # --- Standard WAN transformer renames (shared with base transformer) ---
+      if "condition_embedder" in renamed_pt_key:
+        renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1")
+        renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2")
+        renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj")
+        renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1")
+        renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2")
+
+      if "image_embedder" in renamed_pt_key:
+        if "net.0.proj" in renamed_pt_key:
+          renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
+        elif "net_0.proj" in renamed_pt_key:
+          renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
+        if "net.2" in renamed_pt_key:
+          renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
+        renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
+        if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
+          renamed_pt_key = renamed_pt_key.replace("weight", "scale")
+          renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
+
+      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+      renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
+      renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
+      renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
+      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+
+      # --- Animate-specific renames ---
+      # FusedLeakyReLU bias: HuggingFace stores it under "activation.bias",
+      # JAX stores it under "act_fn.bias" within FlaxMotionConv2d/FlaxMotionLinear.
+      renamed_pt_key = renamed_pt_key.replace(".activation.bias", ".act_fn.bias")
+
+      # face_adapter cross-attention: norm_q/norm_k scale renaming
+      # (rename_for_nnx handles norm_k/norm_q -> scale in get_key_and_value)
+
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      # FlaxMotionConv2d and FlaxMotionLinear store weights as nnx.Param in PyTorch
+      # OIHW / (out, in) format — do NOT rename weight→kernel or transpose.
+      if _is_motion_encoder_custom_weight(renamed_pt_key):
+        flax_key = _tuple_str_to_int(pt_tuple_key)
+        flax_tensor = tensor
+      else:
+        flax_key, flax_tensor = get_key_and_value(
+            pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
+        )
+
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    flax_state_dict = unflatten_dict(flax_state_dict)
+    del tensors
+    jax.clear_caches()
+    return flax_state_dict
+
+
 def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   subfolder = "vae"
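
The loader therefore takes one of two paths per tensor: motion-encoder weights keep their PyTorch name and layout, while everything else goes through the standard get_key_and_value conversion (for a 2-D linear weight that means renaming weight to kernel and transposing (out, in) to (in, out)). A simplified sketch of the split, using a hypothetical port_weight helper, an inlined stand-in for the motion-encoder check, and illustrative shapes:

```python
import numpy as np


def port_weight(pt_key: str, tensor: np.ndarray):
  # Simplified stand-in for _is_motion_encoder_custom_weight above.
  if pt_key.startswith("motion_encoder.") and pt_key.endswith(".weight"):
    # FlaxMotionConv2d / FlaxMotionLinear: keep the "weight" name and the PyTorch layout.
    return tuple(pt_key.split(".")), tensor
  if pt_key.endswith(".weight") and tensor.ndim == 2:
    # Standard nnx.Linear path: rename weight -> kernel and transpose (out, in) -> (in, out).
    return tuple(pt_key.replace(".weight", ".kernel").split(".")), tensor.T
  return tuple(pt_key.split(".")), tensor


print(port_weight("motion_encoder.conv_in.weight", np.zeros((64, 3, 3, 3)))[0][-1])  # 'weight'
print(port_weight("face_adapter.0.to_q.weight", np.zeros((96, 128)))[1].shape)       # (128, 96)
```
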

src/maxdiffusion/pipelines/wan/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,4 +14,6 @@
 limitations under the License.
 """
 
+from .image_processor import WanAnimateImageProcessor
 from .wan_pipeline import WanPipeline
+from .wan_pipeline_animate import WanAnimatePipeline
