@@ -95,47 +95,37 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-        current_padding = list(self._causal_padding)  # Mutable copy
+        current_padding = list(self._causal_padding)
         padding_needed = self._depth_padding_before
 
         if cache_x is not None and padding_needed > 0:
-            # Ensure cache has same spatial/channel dims, potentially different depth
             assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
             cache_len = cache_x.shape[1]
-            x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
-
+            x = jnp.concatenate([cache_x, x], axis=1)
             padding_needed -= cache_len
             if padding_needed < 0:
-                # Cache longer than needed padding, trim from start
                 x = x[:, -padding_needed:, ...]
-                current_padding[1] = (0, 0)  # No explicit padding needed now
+                current_padding[1] = (0, 0)
             else:
-                # Update depth padding needed
                 current_padding[1] = (padding_needed, 0)
 
-        # Apply padding if any dimension requires it
         padding_to_apply = tuple(current_padding)
         if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
             x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
         else:
             x_internal = x
 
-        h_dim_after_conv_padding = x_internal.shape[2]
-        pad_h_fsdp = 0
-        if self.mesh and 'fsdp' in self.mesh.axis_names:
-            fsdp_size = self.mesh.shape['fsdp']
-            if fsdp_size > 1:
-                if h_dim_after_conv_padding % fsdp_size != 0:
-                    pad_h_fsdp = fsdp_size - (h_dim_after_conv_padding % fsdp_size)
-                    h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-                    x_internal = jnp.pad(x_internal, h_padding, mode="constant", constant_values=0.0)
+        # REMOVED FSDP PADDING LOGIC FROM HERE
+        # Sharding constraints are fine, but JAX will error if not divisible.
+        # This will be handled in the calling block.
         if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
             x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))
 
         out = self.conv(x_internal)
         return out
 
 
+
 class WanRMS_norm(nnx.Module):
 
     def __init__(
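Note on the hunk above: the cache path keeps the same bookkeeping — cached frames are prepended along the depth axis and the remaining explicit causal padding shrinks by the cache length. A minimal standalone sketch of that bookkeeping (the shapes and depth_padding_before value are hypothetical, not the module's real state):

import jax.numpy as jnp

# Hypothetical numbers: a causal conv wanting 2 frames of left padding on depth,
# and a cache holding 1 previously seen frame.
depth_padding_before = 2
x = jnp.zeros((1, 4, 8, 8, 16))        # (N, D, H, W, C)
cache_x = jnp.zeros((1, 1, 8, 8, 16))  # same dims as x except depth

x = jnp.concatenate([cache_x, x], axis=1)                # prepend cache along depth
padding_needed = depth_padding_before - cache_x.shape[1]
if padding_needed < 0:
    x = x[:, -padding_needed:, ...]                      # cache longer than needed: drop oldest frames
    padding_needed = 0

x = jnp.pad(x, ((0, 0), (padding_needed, 0), (0, 0), (0, 0), (0, 0)))
print(x.shape)  # (1, 6, 8, 8, 16): 2 frames of left context + 4 new frames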
@@ -328,6 +318,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         # Input x: (N, D, H, W, C), assume C = self.dim
         b, t, h, w, c = x.shape
         assert c == self.dim
+        original_h = h
 
         if self.mode == "upsample3d":
             if feat_cache is not None:
@@ -351,32 +342,37 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
                     x = x.reshape(b, t, h, w, 2, c)
                     x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
                     x = x.reshape(b, t * 2, h, w, c)
+        # Update t and h as they might have changed in upsample3d
         t = x.shape[1]
         h = x.shape[2]
+        # original_h remains the height *before* this block's operations
 
         x_reshaped = x.reshape(b * t, h, w, c)
+        current_h = x_reshaped.shape[1]
 
-        original_h = x_reshaped.shape[1]
+        # --- FSDP Spatial Padding ---
         pad_h_fsdp = 0
         if self.mesh and 'fsdp' in self.mesh.axis_names:
             fsdp_size = self.mesh.shape['fsdp']
             if fsdp_size > 1:
-                if original_h % fsdp_size != 0:
-                    pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+                if current_h % fsdp_size != 0:
+                    pad_h_fsdp = fsdp_size - (current_h % fsdp_size)
                     h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
                     x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
+        # --- End FSDP Spatial Padding ---
 
         if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
             x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))
 
         resampled_x = self.resample(x_reshaped)
 
+        # --- FSDP Spatial Slicing ---
         if pad_h_fsdp > 0:
             if "upsample" in self.mode:
                 scale_factor_h = 1.0
                 if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
                     scale_factor_h = self.resample.layers[0].scale_factor[0]
-                target_h = int(original_h * scale_factor_h)
+                target_h = int(current_h * scale_factor_h)
                 resampled_x = resampled_x[:, :target_h, :, :]
             elif "downsample" in self.mode:
                 stride_h = 1
@@ -386,8 +382,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
                     stride_h = self.resample.strides[0]
 
                 if stride_h > 1:
-                    target_h = original_h // stride_h
+                    # kernel_size and padding affect output size,
+                    # For "VALID" in ZeroPaddedConv2D (which has no other padding), out = (in - kernel + stride) // stride
+                    # Since we added padding for FSDP, we want the size as if no FSDP padding was added.
+                    k_h = self.resample.conv.kernel_size[0]
+                    target_h = (current_h - k_h + stride_h) // stride_h
                     resampled_x = resampled_x[:, :target_h, :, :]
+                # If stride_h is 1, no slicing needed as the size doesn't shrink.
+        # --- End FSDP Spatial Slicing ---
 
         h_new, w_new, c_new = resampled_x.shape[1:]
         x = resampled_x.reshape(b, t, h_new, w_new, c_new)
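The new target_h uses the standard output-size rule for a VALID convolution with stride s: out = (in - kernel + s) // s. A quick check with hypothetical sizes (height 17, 4-way fsdp axis, kernel 3, stride 2) shows the slice removes exactly the rows introduced by the FSDP padding:

# Hypothetical sizes: height 17 before FSDP padding, 4-way fsdp axis,
# downsample conv with kernel 3, stride 2, VALID padding.
current_h, fsdp_size, k_h, stride_h = 17, 4, 3, 2

pad_h_fsdp = (fsdp_size - current_h % fsdp_size) % fsdp_size   # 3 -> padded height 20
padded_h = current_h + pad_h_fsdp

out_padded = (padded_h - k_h + stride_h) // stride_h    # 9 rows after the conv on padded input
target_h = (current_h - k_h + stride_h) // stride_h     # 8 rows expected without FSDP padding
print(out_padded, target_h)                             # 9 8 -> slice off the trailing extra row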
@@ -403,10 +405,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
                     x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
                     feat_cache[idx] = cache_x
                     feat_idx[0] += 1
-
         return x
 
 
+
 class WanResidualBlock(nnx.Module):
 
     def __init__(
@@ -421,6 +423,7 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
     ):
+        self.mesh = mesh
         self.nonlinearity = get_activation(non_linearity)
 
         # layers
@@ -464,39 +467,54 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-        # Apply shortcut connection
-        h = self.conv_shortcut(x)
+        original_shape = x.shape
+        original_h = original_shape[2]
+        original_w = original_shape[3]
+        pad_h_fsdp = 0
+        pad_w_fsdp = 0
+        x_padded = x
 
-        x = self.norm1(x)
-        x = self.nonlinearity(x)
+        if self.mesh and 'fsdp' in self.mesh.axis_names:
+            fsdp_size = self.mesh.shape['fsdp']
+            if fsdp_size > 1:
+                if original_h % fsdp_size != 0:
+                    pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+                # Assuming width is not sharded on fsdp, add if needed
+                # if original_w % fsdp_size != 0:
+                #     pad_w_fsdp = fsdp_size - (original_w % fsdp_size)
 
-        if feat_cache is not None:
-            idx = feat_idx[0]
-            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
-                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-            x = self.conv1(x, feat_cache[idx], idx)
-            feat_cache[idx] = cache_x
-            feat_idx[0] += 1
-        else:
-            x = self.conv1(x)
+        if pad_h_fsdp > 0 or pad_w_fsdp > 0:
+            h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, pad_w_fsdp), (0, 0))
+            x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+
+        h = self.conv_shortcut(x_padded)
+
+        temp_x = self.norm1(x_padded)
+        temp_x = self.nonlinearity(temp_x)
+        temp_x = self.conv1(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
+        temp_x = self.norm2(temp_x)
+        temp_x = self.nonlinearity(temp_x)
+        temp_x = self.conv2(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
+
+        # --- Crop temp_x to match h's spatial dimensions ---
+        h_height, h_width = h.shape[2], h.shape[3]
+        x_height, x_width = temp_x.shape[2], temp_x.shape[3]
+
+        if x_height > h_height:
+            ch = (x_height - h_height) // 2
+            temp_x = temp_x[:, :, ch:ch + h_height, :, :]
+        if x_width > h_width:
+            cw = (x_width - h_width) // 2
+            temp_x = temp_x[:, :, :, cw:cw + h_width, :]
+        # --- End Crop ---
+
+        res_x = temp_x + h
+
+        if pad_h_fsdp > 0 or pad_w_fsdp > 0:
+            res_x = res_x[:, :, :original_h, :original_w, :]
+        return res_x
 
-        x = self.norm2(x)
-        x = self.nonlinearity(x)
-        idx = feat_idx[0]
 
-        if feat_cache is not None:
-            idx = feat_idx[0]
-            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
-                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-            x = self.conv2(x, feat_cache[idx])
-            feat_cache[idx] = cache_x
-            feat_idx[0] += 1
-        else:
-            x = self.conv2(x)
-        x = x + h
-        return x
 
 
 class WanAttentionBlock(nnx.Module):
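The rewritten residual block follows a pad / compute / crop / slice-back pattern: pad H (and optionally W) up to a multiple of the fsdp axis size, run the norm/activation/conv stack on the padded tensor, center-crop the residual branch to the shortcut's spatial size, then slice back to the original shape. A standalone sketch of that pattern, with a hypothetical helper and shapes:

import jax.numpy as jnp

def pad_height_to_multiple(x, multiple):
    """Zero-pad axis 2 (H) of an (N, D, H, W, C) array up to a multiple of `multiple`."""
    pad_h = (multiple - x.shape[2] % multiple) % multiple
    if pad_h:
        x = jnp.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, 0), (0, 0)))
    return x, pad_h

x = jnp.ones((1, 3, 15, 16, 8))              # H = 15, not divisible by a 4-way fsdp axis
x_padded, pad_h = pad_height_to_multiple(x, multiple=4)
print(x_padded.shape)                        # (1, 3, 16, 16, 8): H now shards evenly over 4

out = x_padded * 2.0                         # stand-in for the block's norm/conv computation
if pad_h:
    out = out[:, :, : x.shape[2], :, :]      # slice the padded rows back off
print(out.shape)                             # (1, 3, 15, 16, 8)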
@@ -535,25 +553,27 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array):
-
         identity = x
         batch_size, time, height, width, channels = x.shape
         original_h = height
 
+        # --- FSDP Spatial Padding ---
         pad_h_fsdp = 0
+        x_padded = x
         if self.mesh and 'fsdp' in self.mesh.axis_names:
             fsdp_size = self.mesh.shape['fsdp']
             if fsdp_size > 1:
                 if original_h % fsdp_size != 0:
                     pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
                     h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-                    x = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+                    x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+        # --- End FSDP Spatial Padding ---
 
         if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-            x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+            x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
 
-        current_height = x.shape[2]
-        x_reshaped = x.reshape(batch_size * time, current_height, width, channels)
+        current_height = x_padded.shape[2]
+        x_reshaped = x_padded.reshape(batch_size * time, current_height, width, channels)
         x_normed = self.norm(x_reshaped)
 
         qkv = self.to_qkv(x_normed)
@@ -568,12 +588,16 @@ def __call__(self, x: jax.Array):
 
         x_proj = self.proj(attn_out)
         x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
+
+        # --- FSDP Spatial Slicing ---
         if pad_h_fsdp > 0:
             x_proj = x_proj[:, :, :original_h, :, :]
+        # --- End FSDP Spatial Slicing ---
 
         return x_proj + identity
 
 
+
 class WanMidBlock(nnx.Module):
 
     def __init__(
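Across these blocks the sharding hint itself is just jax.lax.with_sharding_constraint with the 'fsdp' mesh axis mapped to the height dimension; the spatial padding exists only so that dimension divides evenly across the axis (the removed-logic comment above notes JAX errors otherwise). A minimal sketch of the hint on a hypothetical single-axis mesh, using an explicit NamedSharding instead of the bare PartitionSpec the module passes:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all local devices, named 'fsdp'.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
fsdp_size = mesh.shape["fsdp"]

@jax.jit
def hint(x):
    # Keep H (axis 2 of N, D, H, W, C) sharded across the 'fsdp' axis.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, None, "fsdp", None, None)))

x = jnp.ones((1, 3, 4 * fsdp_size, 16, 8))  # H is a multiple of the axis size
y = hint(x)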