@@ -57,6 +57,8 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
     ):
+
+        self.mesh = mesh
         self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
         self.stride = _canonicalize_tuple(stride, 3, "stride")
         padding_tuple = _canonicalize_tuple(padding, 3, "padding")  # (D, H, W) padding amounts
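
Note: each of the constructors touched below now stores the device mesh, which the forward passes query for the size of the `fsdp` axis. A minimal sketch of building such a mesh, assuming a single host and that callers thread it through a `mesh` constructor argument:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical 1 x N mesh: all local devices along an axis named 'fsdp'.
# Modules keep a reference so __call__ can read mesh.axis_names and
# mesh.shape['fsdp'].
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp"))
```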
@@ -114,11 +116,23 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
         # Apply padding if any dimension requires it
         padding_to_apply = tuple(current_padding)
         if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-            x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
+            x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
         else:
-            x_padded = x
-        x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
-        out = self.conv(x_padded)
+            x_internal = x
+
+        h_dim_after_conv_padding = x_internal.shape[2]
+        pad_h_fsdp = 0
+        if self.mesh and 'fsdp' in self.mesh.axis_names:
+            fsdp_size = self.mesh.shape['fsdp']
+            if fsdp_size > 1:
+                if h_dim_after_conv_padding % fsdp_size != 0:
+                    pad_h_fsdp = fsdp_size - (h_dim_after_conv_padding % fsdp_size)
+                    h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+                    x_internal = jnp.pad(x_internal, h_padding, mode="constant", constant_values=0.0)
+        if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+            x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))
+
+        out = self.conv(x_internal)
         return out
 
 
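This hunk pads the height axis (axis 2 of the NDHWC activation) up to the next multiple of the `fsdp` axis size before applying the sharding constraint, so every device receives an equal slice. The same pad-to-multiple arithmetic recurs in the hunks below; a self-contained sketch of it (the helper itself is not in the diff):

```python
import jax.numpy as jnp

def pad_to_multiple(x: jnp.ndarray, axis: int, multiple: int):
    """Zero-pad `x` along `axis` up to the next multiple of `multiple`.

    Returns (padded, pad) so the caller can crop the padded rows off
    again after the sharded op."""
    pad = (-x.shape[axis]) % multiple  # 0 when already divisible
    if pad == 0:
        return x, 0
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return jnp.pad(x, widths, mode="constant", constant_values=0.0), pad

# H = 135 on a 4-way 'fsdp' axis pads to 136 with one row of zeros.
x = jnp.zeros((1, 5, 135, 240, 16))  # (B, T, H, W, C)
x_padded, pad_h = pad_to_multiple(x, axis=2, multiple=4)
assert x_padded.shape[2] == 136 and pad_h == 1
```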
@@ -225,6 +239,7 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
     ):
+        self.mesh = mesh
         self.dim = dim
         self.mode = mode
         self.time_conv = nnx.data(None)
@@ -336,12 +351,46 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
                 x = x.reshape(b, t, h, w, 2, c)
                 x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
                 x = x.reshape(b, t * 2, h, w, c)
-        t = x.shape[1]
-        x = x.reshape(b * t, h, w, c)
-        x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None))
-        x = self.resample(x)
-        h_new, w_new, c_new = x.shape[1:]
-        x = x.reshape(b, t, h_new, w_new, c_new)
+        t = x.shape[1]
+        h = x.shape[2]
+
+        x_reshaped = x.reshape(b * t, h, w, c)
+
+        original_h = x_reshaped.shape[1]
+        pad_h_fsdp = 0
+        if self.mesh and 'fsdp' in self.mesh.axis_names:
+            fsdp_size = self.mesh.shape['fsdp']
+            if fsdp_size > 1:
+                if original_h % fsdp_size != 0:
+                    pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+                    h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+                    x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
+
+        if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+            x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))
+
+        resampled_x = self.resample(x_reshaped)
+
+        if pad_h_fsdp > 0:
+            if "upsample" in self.mode:
+                scale_factor_h = 1.0
+                if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
+                    scale_factor_h = self.resample.layers[0].scale_factor[0]
+                target_h = int(original_h * scale_factor_h)
+                resampled_x = resampled_x[:, :target_h, :, :]
+            elif "downsample" in self.mode:
+                stride_h = 1
+                if isinstance(self.resample, ZeroPaddedConv2D):
+                    stride_h = self.resample.conv.strides[0]
+                elif isinstance(self.resample, nnx.Conv):
+                    stride_h = self.resample.strides[0]
+
+                if stride_h > 1:
+                    target_h = original_h // stride_h
+                    resampled_x = resampled_x[:, :target_h, :, :]
+
+        h_new, w_new, c_new = resampled_x.shape[1:]
+        x = resampled_x.reshape(b, t, h_new, w_new, c_new)
 
         if self.mode == "downsample3d":
             if feat_cache is not None:
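Unlike the causal conv above, the resample path cannot simply keep the padded rows: the output height changes, so the crop target must be rescaled. The hunk reads the actual scale factor or stride off the resample layer; the arithmetic reduces to the following sketch, where the 2x factor and stride 2 are assumptions matching the usual Wan VAE resample blocks:

```python
def cropped_height(original_h: int, pad_h: int, mode: str,
                   scale_h: float = 2.0, stride_h: int = 2):
    """Valid (non-padding) output height after resampling, or None if
    no crop is needed. scale_h / stride_h are assumed here; the diff
    reads them from the WanUpsample / conv layers at runtime."""
    if pad_h == 0:
        return None
    if "upsample" in mode:
        # pad_h zero rows become int(scale_h * pad_h) rows at the bottom.
        return int(original_h * scale_h)
    if "downsample" in mode and stride_h > 1:
        # The crop drops trailing outputs computed from the padded rows.
        return original_h // stride_h
    return None

assert cropped_height(135, 1, "upsample2d") == 270   # 136 -> 272, keep 270
assert cropped_height(135, 1, "downsample2d") == 67  # 136 -> 68, keep 67
```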
@@ -461,6 +510,7 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
     ):
+        self.mesh = mesh
         self.dim = dim
         self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
         self.to_qkv = nnx.Conv(
@@ -488,29 +538,40 @@ def __call__(self, x: jax.Array):
 
         identity = x
         batch_size, time, height, width, channels = x.shape
+        original_h = height
+
+        pad_h_fsdp = 0
+        if self.mesh and 'fsdp' in self.mesh.axis_names:
+            fsdp_size = self.mesh.shape['fsdp']
+            if fsdp_size > 1:
+                if original_h % fsdp_size != 0:
+                    pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+                    h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+                    x = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
 
-        x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+        if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+            x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
 
-        x = x.reshape(batch_size * time, height, width, channels)
-        x = self.norm(x)
+        current_height = x.shape[2]
+        x_reshaped = x.reshape(batch_size * time, current_height, width, channels)
+        x_normed = self.norm(x_reshaped)
 
-        qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
-        # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        qkv = self.to_qkv(x_normed)
         qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
         qkv = jnp.transpose(qkv, (0, 1, 3, 2))
         q, k, v = jnp.split(qkv, 3, axis=-2)
         q = jnp.transpose(q, (0, 1, 3, 2))
         k = jnp.transpose(k, (0, 1, 3, 2))
         v = jnp.transpose(v, (0, 1, 3, 2))
-        x = jax.nn.dot_product_attention(q, k, v)
-        x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
+        attn_out = jax.nn.dot_product_attention(q, k, v)
+        attn_out = jnp.squeeze(attn_out, 1).reshape(batch_size * time, current_height, width, channels)
 
-        # output projection
-        x = self.proj(x)
-        # Reshape back
-        x = x.reshape(batch_size, time, height, width, channels)
+        x_proj = self.proj(attn_out)
+        x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
+        if pad_h_fsdp > 0:
+            x_proj = x_proj[:, :, :original_h, :, :]
 
-        return x + identity
+        return x_proj + identity
 
 
 class WanMidBlock(nnx.Module):
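
The guard `self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1` now appears verbatim at three call sites. A possible refactor, not part of this diff, that also binds the mesh into an explicit `NamedSharding` instead of relying on an ambient mesh context:

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def maybe_shard(x, mesh, spec: P):
    """Constrain `x` to `spec` only when a real 'fsdp' axis exists;
    mesh=None and single-device runs pass through unchanged."""
    if mesh is not None and 'fsdp' in mesh.axis_names and mesh.shape['fsdp'] > 1:
        return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
    return x

# e.g. x = maybe_shard(x, self.mesh, P(None, None, 'fsdp', None, None))
```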