@@ -57,8 +57,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-
-    self.mesh = mesh
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
     self.stride = _canonicalize_tuple(stride, 3, "stride")
     padding_tuple = _canonicalize_tuple(padding, 3, "padding")  # (D, H, W) padding amounts
@@ -95,37 +93,35 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-    current_padding = list(self._causal_padding)
+    current_padding = list(self._causal_padding)  # Mutable copy
     padding_needed = self._depth_padding_before
 
     if cache_x is not None and padding_needed > 0:
+      # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
       cache_len = cache_x.shape[1]
-      x = jnp.concatenate([cache_x, x], axis=1)
+      x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
+
       padding_needed -= cache_len
       if padding_needed < 0:
+        # Cache longer than needed padding, trim from start
         x = x[:, -padding_needed:, ...]
-        current_padding[1] = (0, 0)
+        current_padding[1] = (0, 0)  # No explicit padding needed now
       else:
+        # Update depth padding needed
         current_padding[1] = (padding_needed, 0)
 
+    # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-      x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
+      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
-      x_internal = x
-
-    # REMOVED FSDP PADDING LOGIC FROM HERE
-    # Sharding constraints are fine, but JAX will error if not divisible.
-    # This will be handled in the calling block.
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))
-
-    out = self.conv(x_internal)
+      x_padded = x
+    x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
+    out = self.conv(x_padded)
     return out
 
 
-
 class WanRMS_norm(nnx.Module):
 
   def __init__(
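
For readers following the sharding change: the padding logic removed above is replaced by a bare `jax.lax.with_sharding_constraint` on the (N, D, H, W, C) activation, splitting the height axis over the `'fsdp'` mesh axis. Below is a minimal sketch of that pattern; the mesh construction, device count, and shapes are illustrative assumptions, not values from this PR, and (as the removed comment notes) JAX expects H to be divisible by the `'fsdp'` axis size.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh named 'fsdp' over all local devices (assumed setup).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

@jax.jit
def shard_height(x):
  # (N, D, H, W, C): ask the compiler to keep the H axis split across 'fsdp'.
  return jax.lax.with_sharding_constraint(
      x, NamedSharding(mesh, P(None, None, "fsdp", None, None)))

# H chosen as a multiple of the device count so the constraint is satisfiable.
x = jnp.zeros((1, 4, 8 * jax.device_count(), 8, 16))
print(shard_height(x).sharding)
```
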
@@ -229,7 +225,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.dim = dim
     self.mode = mode
     self.time_conv = nnx.data(None)
@@ -318,7 +313,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
     # Input x: (N, D, H, W, C), assume C = self.dim
     b, t, h, w, c = x.shape
     assert c == self.dim
-    original_h = h
 
     if self.mode == "upsample3d":
       if feat_cache is not None:
@@ -342,57 +336,12 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
           x = x.reshape(b, t, h, w, 2, c)
           x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
           x = x.reshape(b, t * 2, h, w, c)
-    # Update t and h as they might have changed in upsample3d
-    t = x.shape[1]
-    h = x.shape[2]
-    # original_h remains the height *before* this block's operations
-
-    x_reshaped = x.reshape(b * t, h, w, c)
-    current_h = x_reshaped.shape[1]
-
-    # --- FSDP Spatial Padding ---
-    pad_h_fsdp = 0
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if current_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (current_h % fsdp_size)
-          h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
-    # --- End FSDP Spatial Padding ---
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))
-
-    resampled_x = self.resample(x_reshaped)
-
-    # --- FSDP Spatial Slicing ---
-    if pad_h_fsdp > 0:
-      if "upsample" in self.mode:
-        scale_factor_h = 1.0
-        if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
-          scale_factor_h = self.resample.layers[0].scale_factor[0]
-        target_h = int(current_h * scale_factor_h)
-        resampled_x = resampled_x[:, :target_h, :, :]
-      elif "downsample" in self.mode:
-        stride_h = 1
-        if isinstance(self.resample, ZeroPaddedConv2D):
-          stride_h = self.resample.conv.strides[0]
-        elif isinstance(self.resample, nnx.Conv):
-          stride_h = self.resample.strides[0]
-
-        if stride_h > 1:
-          # kernel_size and padding affect output size,
-          # For "VALID" in ZeroPaddedConv2D (which has no other padding), out = (in - kernel + stride) // stride
-          # Since we added padding for FSDP, we want the size as if no FSDP padding was added.
-          k_h = self.resample.conv.kernel_size[0]
-          target_h = (current_h - k_h + stride_h) // stride_h
-          resampled_x = resampled_x[:, :target_h, :, :]
-        # If stride_h is 1, no slicing needed as the size doesn't shrink.
-    # --- End FSDP Spatial Slicing ---
-
-    h_new, w_new, c_new = resampled_x.shape[1:]
-    x = resampled_x.reshape(b, t, h_new, w_new, c_new)
+    t = x.shape[1]
+    x = x.reshape(b * t, h, w, c)
+    x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None))
+    x = self.resample(x)
+    h_new, w_new, c_new = x.shape[1:]
+    x = x.reshape(b, t, h_new, w_new, c_new)
 
     if self.mode == "downsample3d":
       if feat_cache is not None:
@@ -405,8 +354,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
           x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
           feat_cache[idx] = cache_x
           feat_idx[0] += 1
-    return x
 
+    return x
 
 
 class WanResidualBlock(nnx.Module):
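
The `upsample3d` branch kept above doubles the temporal length by splitting the `2 * c` channels produced by `time_conv` into two per-frame groups and folding them into the time axis. A toy-sized shape check of the same reshape/stack sequence follows; all sizes are arbitrary assumptions.

```python
import jax.numpy as jnp

# Toy sizes: batch 1, 2 frames, 2x2 spatial, c = 3 channels per output frame,
# so the tensor entering this step is assumed to carry 2 * c channels.
b, t, h, w, c = 1, 2, 2, 2, 3
x = jnp.zeros((b, t, h, w, 2 * c))

x = x.reshape(b, t, h, w, 2, c)                                    # split channels into two frame groups
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)  # (b, 2, t, h, w, c)
x = x.reshape(b, t * 2, h, w, c)                                   # temporal length doubled

print(x.shape)  # (1, 4, 2, 2, 3)
```
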
@@ -423,7 +372,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.nonlinearity = get_activation(non_linearity)
 
     # layers
@@ -467,54 +415,39 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    original_shape = x.shape
-    original_h = original_shape[2]
-    original_w = original_shape[3]
-    pad_h_fsdp = 0
-    pad_w_fsdp = 0
-    x_padded = x
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if original_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
-        # Assuming width is not sharded on fsdp, add if needed
-        # if original_w % fsdp_size != 0:
-        #   pad_w_fsdp = fsdp_size - (original_w % fsdp_size)
-
-    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
-      h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, pad_w_fsdp), (0, 0))
-      x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
-
-    h = self.conv_shortcut(x_padded)
+    # Apply shortcut connection
+    h = self.conv_shortcut(x)
 
-    temp_x = self.norm1(x_padded)
-    temp_x = self.nonlinearity(temp_x)
-    temp_x = self.conv1(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
-    temp_x = self.norm2(temp_x)
-    temp_x = self.nonlinearity(temp_x)
-    temp_x = self.conv2(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
-
-    # --- Crop temp_x to match h's spatial dimensions ---
-    h_height, h_width = h.shape[2], h.shape[3]
-    x_height, x_width = temp_x.shape[2], temp_x.shape[3]
-
-    if x_height > h_height:
-      ch = (x_height - h_height) // 2
-      temp_x = temp_x[:, :, ch:ch + h_height, :, :]
-    if x_width > h_width:
-      cw = (x_width - h_width) // 2
-      temp_x = temp_x[:, :, :, cw:cw + h_width, :]
-    # --- End Crop ---
-
-    res_x = temp_x + h
+    x = self.norm1(x)
+    x = self.nonlinearity(x)
 
-    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
-      res_x = res_x[:, :, :original_h, :original_w, :]
-    return res_x
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv1(x, feat_cache[idx], idx)
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv1(x)
 
+    x = self.norm2(x)
+    x = self.nonlinearity(x)
+    idx = feat_idx[0]
 
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv2(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv2(x)
+    x = x + h
+    return x
 
 
 class WanAttentionBlock(nnx.Module):
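
The caching restored in this block keeps the last `CACHE_T` frames of the activation entering each causal conv and, when the incoming chunk is a single frame, prepends the last frame of the previous cache entry so two frames of context are always stored. Below is a stripped-down sketch of that bookkeeping, with assumed shapes and an assumed `CACHE_T` of 2.

```python
import jax.numpy as jnp

CACHE_T = 2  # assumed value for illustration

def update_cache(x, prev_cache):
  # x: (N, D, H, W, C) activation about to enter the causal conv.
  # prev_cache: previously stored frames for this layer, or None.
  cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
  if cache_x.shape[1] < 2 and prev_cache is not None:
    # Chunk shorter than CACHE_T: borrow the last frame of the old cache.
    last = jnp.expand_dims(prev_cache[:, -1, :, :, :], axis=1)
    cache_x = jnp.concatenate([last, cache_x], axis=1)
  return cache_x

prev = jnp.zeros((1, 2, 4, 4, 8))   # cache from the previous chunk
chunk = jnp.ones((1, 1, 4, 4, 8))   # single new frame
print(update_cache(chunk, prev).shape)  # (1, 2, 4, 4, 8)
```
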
@@ -528,7 +461,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.dim = dim
     self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
     self.to_qkv = nnx.Conv(
@@ -553,49 +485,32 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array):
+
     identity = x
     batch_size, time, height, width, channels = x.shape
-    original_h = height
-
-    # --- FSDP Spatial Padding ---
-    pad_h_fsdp = 0
-    x_padded = x
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if original_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
-          h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
-    # --- End FSDP Spatial Padding ---
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
-
-    current_height = x_padded.shape[2]
-    x_reshaped = x_padded.reshape(batch_size * time, current_height, width, channels)
-    x_normed = self.norm(x_reshaped)
-
-    qkv = self.to_qkv(x_normed)
+
+    x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
+    # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
     q, k, v = jnp.split(qkv, 3, axis=-2)
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
-    attn_out = jax.nn.dot_product_attention(q, k, v)
-    attn_out = jnp.squeeze(attn_out, 1).reshape(batch_size * time, current_height, width, channels)
-
-    x_proj = self.proj(attn_out)
-    x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
-
-    # --- FSDP Spatial Slicing ---
-    if pad_h_fsdp > 0:
-      x_proj = x_proj[:, :, :original_h, :, :]
-    # --- End FSDP Spatial Slicing ---
+    x = jax.nn.dot_product_attention(q, k, v)
+    x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
 
-    return x_proj + identity
+    # output projection
+    x = self.proj(x)
+    # Reshape back
+    x = x.reshape(batch_size, time, height, width, channels)
 
+    return x + identity
 
 
 class WanMidBlock(nnx.Module):
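
For orientation, the attention block treats every spatial position of a frame as a token and lets `jax.nn.dot_product_attention` mix them. The sketch below shows the conventional single-head layout for that call, with the H*W positions on the sequence axis; it is a generic illustration and does not reproduce the exact axis ordering used in the block above. The `to_qkv` and `proj` callables are hypothetical stand-ins for the block's convolutions.

```python
import jax
import jax.numpy as jnp

def spatial_self_attention(x, to_qkv, proj):
  # x: (N, H, W, C). jax.nn.dot_product_attention expects
  # (batch, seq_len, num_heads, head_dim), so the H*W positions become the
  # sequence and there is a single head of width C.
  n, h, w, c = x.shape
  qkv = to_qkv(x)                              # (N, H, W, 3C)
  qkv = qkv.reshape(n, h * w, 1, 3 * c)        # tokens = H*W, heads = 1
  q, k, v = jnp.split(qkv, 3, axis=-1)         # each (N, H*W, 1, C)
  out = jax.nn.dot_product_attention(q, k, v)  # (N, H*W, 1, C)
  return proj(out.reshape(n, h, w, c))

# Toy usage with identity-style stand-ins for the conv layers.
x = jnp.ones((2, 4, 4, 8))
y = spatial_self_attention(x, lambda a: jnp.concatenate([a, a, a], axis=-1), lambda a: a)
print(y.shape)  # (2, 4, 4, 8)
```
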
@@ -1234,4 +1149,4 @@ def decode(
     decoded = self._decode(z, feat_cache).sample
     if not return_dict:
       return (decoded,)
-    return FlaxDecoderOutput(sample=decoded)
+    return FlaxDecoderOutput(sample=decoded)