 import math
 import jax
 import jax.numpy as jnp
+import flax.linen as nn
 from flax import nnx
 from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin
@@ -81,6 +82,13 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxMotionConv2d(nnx.Module):
+  """2-D convolution with EqualizedLR scaling and optional FusedLeakyReLU.
+
+  Weights are stored in PyTorch OIHW format (out, in, k, k) as raw nnx.Param
+  so that the weight-loading code in wan_utils.py can map them without
+  transposing. No sharding annotations are applied because this module is
+  part of the small motion encoder network.
+  """
 
   def __init__(
       self,
@@ -123,7 +131,7 @@ def __init__(
       self.blur_kernel = None
 
     key = rngs.params()
-    # Shape: (out_channels, in_channels, kernel, kernel) mapping PyTorch 'OIHW'
+    # Shape: (out_channels, in_channels, kernel, kernel) — PyTorch OIHW format.
     self.weight = nnx.Param(jax.random.normal(key, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype))
     self.scale = 1.0 / math.sqrt(in_channels * kernel_size**2)
 
@@ -156,7 +164,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
           x,
           expanded_kernel,
           window_strides=(1, 1),
-          padding=[(pad_h, pad_h), (pad_w, pad_w)],  # Corrected Symmetric Padding
+          padding=[(pad_h, pad_h), (pad_w, pad_w)],
           dimension_numbers=("NCHW", "OIHW", "NCHW"),
           feature_group_count=self.in_channels,
       )
@@ -186,6 +194,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxMotionLinear(nnx.Module):
+  """Equalized-LR linear layer with optional FusedLeakyReLU.
+
+  Weights are stored in PyTorch (out, in) format as raw nnx.Param — same
+  reason as FlaxMotionConv2d. No sharding annotations needed (small layer).
+  """
 
   def __init__(
       self,
@@ -296,6 +309,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
 
 
 class FlaxWanAnimateMotionEncoder(nnx.Module):
+  """Encodes a face video frame into a motion vector.
+
+  All weights in this network are small (the largest is 32×512→16) so
+  sharding annotations are not applied.
+  """
 
   def __init__(
       self,
@@ -395,7 +413,6 @@ def __init__(
 
     self.act = jax.nn.silu
 
-    # Added explicit padding="VALID" to exactly mirror PyTorch's padding=0 default
     self.conv1_local = nnx.Conv(
         in_dim,
         hidden_dim * num_heads,
@@ -449,7 +466,15 @@ def __init__(
         dtype=dtype,
     )
 
-    self.out_proj = nnx.Linear(hidden_dim, out_dim, rngs=rngs, dtype=dtype)
+    # hidden_dim (mlp) → out_dim (embed): ("mlp", "embed")
+    self.out_proj = nnx.Linear(
+        hidden_dim,
+        out_dim,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
 
     self.padding_tokens = nnx.Param(jnp.zeros((1, 1, 1, out_dim), dtype=dtype))
 
@@ -510,11 +535,45 @@ def __init__(
     self.pre_norm_q = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype)
     self.pre_norm_kv = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype)
 
-    self.to_q = nnx.Linear(dim, self.inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
-    self.to_k = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
-    self.to_v = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
+    # embed → heads
+    self.to_q = nnx.Linear(
+        dim,
+        self.inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
+    self.to_k = nnx.Linear(
+        dim,
+        self.kv_inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
+    self.to_v = nnx.Linear(
+        dim,
+        self.kv_inner_dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
+    )
 
-    self.to_out = nnx.Linear(self.inner_dim, dim, use_bias=use_bias, rngs=rngs, dtype=dtype)
+    # heads → embed
+    self.to_out = nnx.Linear(
+        self.inner_dim,
+        dim,
+        use_bias=use_bias,
+        rngs=rngs,
+        dtype=dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
 
     self.norm_q = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype)
     self.norm_k = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype)
@@ -544,14 +603,14 @@ def __call__(
 
     query_S = query.shape[1]
 
-    # Prepare for attention by folding Time into the Batch dimension
+    # Fold Time into the Batch dimension for attention
     query = jnp.reshape(query, (B * T, query_S // T, self.heads, -1))
     key = jnp.reshape(key, (B * T, N, self.heads, -1))
     value = jnp.reshape(value, (B * T, N, self.heads, -1))
 
     attn_output = jax.nn.dot_product_attention(query, key, value)
 
-    # Collapse Time, Seq Length, and Heads straight back to (Batch, Total Sequence, Dim)
+    # Restore (Batch, Total Sequence, Dim)
     attn_output = jnp.reshape(attn_output, (B, query_S, -1))
 
     hidden_states = self.to_out(attn_output)
@@ -624,6 +683,8 @@ def __init__(
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
 
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+
+    # Patch embeddings — shard output (conv_out) axis across model parallelism.
     self.patch_embedding = nnx.Conv(
         in_channels,
         inner_dim,
@@ -632,6 +693,10 @@ def __init__(
         rngs=rngs,
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
     )
     self.pose_patch_embedding = nnx.Conv(
         latent_channels,
@@ -641,6 +706,10 @@ def __init__(
         rngs=rngs,
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
     )
 
     self.condition_embedder = WanTimeTextImageEmbedding(
@@ -714,15 +783,22 @@ def __init__(
     self.face_adapter = nnx.List(face_adapters)
 
     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
+
+    # Final projection — embed → output tokens.
     self.proj_out = nnx.Linear(
         rngs=rngs,
         in_features=inner_dim,
         out_features=out_channels * math.prod(patch_size),
         dtype=dtype,
         param_dtype=weights_dtype,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
     )
+
     key = rngs.params()
-    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim), dtype=dtype) / inner_dim**0.5)
+    self.scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 2, inner_dim), dtype=dtype) / inner_dim**0.5,
+        sharding=(None, None, "embed"),
+    )
 
   def conditional_named_scope(self, name: str):
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
@@ -747,6 +823,9 @@ def __call__(
           f"Pose frames + 1 ({pose_hidden_states.shape[2]} + 1) must equal hidden_states frames ({hidden_states.shape[2]})"
       )
 
+    # Constrain input to batch-sharded layout before any computation.
+    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
+
     batch_size, num_channels, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.patch_size
     post_patch_num_frames = num_frames // p_t
@@ -850,7 +929,7 @@ def encode_chunk_fn(carry, chunk):
             rngs,
         )
 
-        # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...)
+        # Face adapter integration: apply after every inject_face_latents_blocks-th block
        if motion_vec is not None and block_idx % self.inject_face_latents_blocks == 0:
          face_adapter_block_idx = block_idx // self.inject_face_latents_blocks
          face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
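
For reference, a minimal sketch (not part of this diff) of what the nnx.with_partitioning annotations above do: they attach a sharding metadata tuple of logical axis names to each parameter, which nnx.get_partition_spec later turns into PartitionSpecs. The TinyProj module and its dimensions below are made up for illustration.

from flax import nnx


class TinyProj(nnx.Module):
  """Illustrative module using the same embed -> heads annotation as to_q above."""

  def __init__(self, dim: int, inner_dim: int, rngs: nnx.Rngs):
    self.proj = nnx.Linear(
        dim,
        inner_dim,
        rngs=rngs,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
    )


module = TinyProj(dim=8, inner_dim=16, rngs=nnx.Rngs(0))
# Each annotated param carries .sharding metadata; get_partition_spec reads it back out,
# e.g. kernel -> PartitionSpec('embed', 'heads'), bias -> PartitionSpec('heads').
print(nnx.get_partition_spec(nnx.state(module)))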