@@ -485,15 +485,15 @@ def __init__(
485485 self .conv_shortcut = Identity ()
486486 if in_dim != out_dim :
487487 self .conv_shortcut = WanCausalConv3d (
488- rngs = rngs ,
489- in_channels = in_dim ,
490- out_channels = out_dim ,
491- kernel_size = 1 ,
492- mesh = mesh ,
493- dtype = dtype ,
494- weights_dtype = weights_dtype ,
495- precision = precision ,
496- )
488+ rngs = rngs ,
489+ in_channels = in_dim ,
490+ out_channels = out_dim ,
491+ kernel_size = 1 ,
492+ mesh = mesh ,
493+ dtype = dtype ,
494+ weights_dtype = weights_dtype ,
495+ precision = precision ,
496+ )
497497
498498 def initialize_cache (self , batch_size , height , width , dtype ):
499499 """Initialize cache for all convolutions."""
@@ -572,42 +572,42 @@ def __init__(
572572 )
573573
def __call__(self, x: jax.Array) -> jax.Array:
    """Apply single-head spatial self-attention independently to each frame.

    Args:
        x: Video tensor of shape (B, T, H, W, C).

    Returns:
        Tensor of shape (B, T, H, W, C): the input plus a residual
        attention update.
    """
    identity = x
    batch_size, time, height, width, channels = x.shape

    # Fold time into batch so every frame is processed independently.
    x = x.reshape(batch_size * time, height, width, channels)
    x = self.norm(x)

    qkv = self.to_qkv(x)  # (B*T, H, W, 3*C)

    # Read the shape after to_qkv to avoid reusing stale outer variables.
    bt, h, w, c3 = qkv.shape

    # Flatten spatial dims and split the channel axis directly into q, k, v.
    # Identical result to transpose -> split(axis=1) -> transpose-back,
    # without the two extra transposes.
    q, k, v = jnp.split(qkv.reshape(bt, h * w, c3), 3, axis=-1)  # each (B*T, H*W, C)

    # jax.nn.dot_product_attention expects BTNH layout:
    # (batch, seq_len, num_heads, head_dim). The H*W spatial positions are
    # the sequence; use a single head.
    # NOTE(review): the previous code inserted the singleton axis at
    # position 1, which made seq_len == 1 (softmax over one token returns v
    # unchanged, i.e. attention was a no-op) — confirm against the JAX API.
    q = jnp.expand_dims(q, 2)  # (B*T, H*W, 1, C)
    k = jnp.expand_dims(k, 2)
    v = jnp.expand_dims(v, 2)

    x = jax.nn.dot_product_attention(q, k, v)  # (B*T, H*W, 1, C)
    x = jnp.squeeze(x, 2)  # (B*T, H*W, C)

    # Restore the spatial layout and project back to the input channels.
    x = x.reshape(bt, h, w, channels)
    x = self.proj(x)

    # Un-fold time and add the residual connection.
    x = x.reshape(batch_size, time, height, width, channels)
    return x + identity
611611
612612
613613class WanMidBlock (nnx .Module ):
@@ -626,18 +626,18 @@ def __init__(
626626 self .dim = dim
627627 self .resnets = nnx .List (
628628 [
629- WanResidualBlock (
630- in_dim = dim ,
631- out_dim = dim ,
632- rngs = rngs ,
633- dropout = dropout ,
634- non_linearity = non_linearity ,
635- mesh = mesh ,
636- dtype = dtype ,
637- weights_dtype = weights_dtype ,
638- precision = precision ,
639- )
640- ]
629+ WanResidualBlock (
630+ in_dim = dim ,
631+ out_dim = dim ,
632+ rngs = rngs ,
633+ dropout = dropout ,
634+ non_linearity = non_linearity ,
635+ mesh = mesh ,
636+ dtype = dtype ,
637+ weights_dtype = weights_dtype ,
638+ precision = precision ,
639+ )
640+ ]
641641 )
642642 self .attentions = nnx .List ([])
643643 for _ in range (num_layers ):
@@ -991,18 +991,18 @@ def __init__(
991991 upsample_mode = "upsample3d" if temperal_upsample [i ] else "upsample2d"
992992 self .up_blocks .append (
993993 WanUpBlock (
994- in_dim = in_dim ,
995- out_dim = out_dim ,
996- num_res_blocks = num_res_blocks ,
997- dropout = dropout ,
998- upsample_mode = upsample_mode ,
999- non_linearity = non_linearity ,
1000- rngs = rngs ,
1001- mesh = mesh ,
1002- dtype = dtype ,
1003- weights_dtype = weights_dtype ,
1004- precision = precision ,
1005- )
994+ in_dim = in_dim ,
995+ out_dim = out_dim ,
996+ num_res_blocks = num_res_blocks ,
997+ dropout = dropout ,
998+ upsample_mode = upsample_mode ,
999+ non_linearity = non_linearity ,
1000+ rngs = rngs ,
1001+ mesh = mesh ,
1002+ dtype = dtype ,
1003+ weights_dtype = weights_dtype ,
1004+ precision = precision ,
1005+ )
10061006 )
10071007
10081008 self .norm_out = WanRMS_norm (
@@ -1176,22 +1176,44 @@ def encode(
11761176 if x .shape [- 1 ] != 3 :
11771177 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
11781178
1179- x_scan = jnp .swapaxes (x , 0 , 1 ) # (B, T, H, W, C) -> (T, B, H, W, C)
1179+ # Calculate temporal downsampling factor
1180+ temporal_downsample_factor = 1
1181+ for ds in self .temperal_downsample :
1182+ if ds :
1183+ temporal_downsample_factor *= 2
1184+
11801185 b , t , h , w , c = x .shape
1186+
1187+ # Process frames in chunks that match temporal downsampling
1188+ # This prevents frames from being downsampled to 0
1189+ chunk_size = temporal_downsample_factor
1190+
1191+ # Pad time dimension if needed to make it divisible by chunk_size
1192+ if t % chunk_size != 0 :
1193+ pad_frames = chunk_size - (t % chunk_size )
1194+ x = jnp .pad (x , ((0 , 0 ), (0 , pad_frames ), (0 , 0 ), (0 , 0 ), (0 , 0 )), mode = 'edge' )
1195+ t = x .shape [1 ]
1196+
1197+ # Reshape to process chunks: (B, T, H, W, C) -> (T//chunk_size, B, chunk_size, H, W, C)
1198+ x_chunks = x .reshape (b , t // chunk_size , chunk_size , h , w , c )
1199+ x_scan = jnp .swapaxes (x_chunks , 0 , 1 ) # -> (T//chunk_size, B, chunk_size, H, W, C)
1200+
11811201 init_cache = self .encoder .init_cache (b , h , w , x .dtype )
11821202
1183- def scan_fn (carry , input_slice ):
1184- """Scan function processes one frame at a time."""
1185- # Expand time dimension for Conv3d compatibility
1186- input_slice = jnp .expand_dims (input_slice , 1 ) # (B, H, W, C) -> (B, 1, H, W, C)
1187- out_slice , new_carry = self .encoder (input_slice , carry )
1188- # Squeeze time dimension for scan stacking
1189- out_slice = jnp .squeeze (out_slice , 1 ) # (B, 1, H', W', C') -> (B, H', W', C')
1190- return new_carry , out_slice
1203+ def scan_fn (carry , input_chunk ):
1204+ """Scan function processes one chunk of frames at a time."""
1205+ # input_chunk shape: (B, chunk_size, H, W, C)
1206+ out_chunk , new_carry = self .encoder (input_chunk , carry )
1207+ return new_carry , out_chunk
11911208
11921209 # Use jax.lax.scan for JIT-compilable temporal iteration
1193- final_cache , encoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1194- encoded = jnp .swapaxes (encoded_frames , 0 , 1 ) # (T, B, H', W', C') -> (B, T, H', W', C')
1210+ final_cache , encoded_chunks = jax .lax .scan (scan_fn , init_cache , x_scan )
1211+ # encoded_chunks shape: (T//chunk_size, B, T_out_per_chunk, H', W', C')
1212+
1213+ # Reshape back: (T//chunk_size, B, T_out, H', W', C') -> (B, T_total, H', W', C')
1214+ n_chunks , batch , t_per_chunk , h_out , w_out , c_out = encoded_chunks .shape
1215+ encoded = jnp .transpose (encoded_chunks , (1 , 0 , 2 , 3 , 4 , 5 )) # (B, n_chunks, T_out, H', W', C')
1216+ encoded = encoded .reshape (batch , n_chunks * t_per_chunk , h_out , w_out , c_out )
11951217
11961218 # Apply quantization convolution
11971219 enc , _ = self .quant_conv (encoded )
@@ -1221,9 +1243,18 @@ def decode(
12211243
12221244 # Apply post-quantization convolution
12231245 x , _ = self .post_quant_conv (z )
1224- x_scan = jnp .swapaxes (x , 0 , 1 ) # (B, T, H, W, C) -> (T, B, H, W, C)
1225-
1246+
1247+ # Calculate temporal upsampling factor
1248+ temporal_upsample_factor = 1
1249+ for us in self .temporal_upsample :
1250+ if us :
1251+ temporal_upsample_factor *= 2
1252+
12261253 b , t , h , w , c = x .shape
1254+
1255+ # For decoder, we still process one frame at a time but output will be upsampled
1256+ x_scan = jnp .swapaxes (x , 0 , 1 ) # (B, T, H, W, C) -> (T, B, H, W, C)
1257+
12271258 init_cache = self .decoder .init_cache (b , h , w , x .dtype )
12281259
12291260 def scan_fn (carry , input_slice ):
@@ -1238,11 +1269,11 @@ def scan_fn(carry, input_slice):
12381269 # Use jax.lax.scan for JIT-compilable temporal iteration
12391270 final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
12401271
1241- # decoded_frames shape: (T_lat, B, 4 , H, W, C)
1242- # Transpose to (B, T_lat, 4 , H, W, C)
1272+ # decoded_frames shape: (T_lat, B, T_upsample , H, W, C)
1273+ # Transpose to (B, T_lat, T_upsample , H, W, C)
12431274 decoded = jnp .transpose (decoded_frames , (1 , 0 , 2 , 3 , 4 , 5 ))
12441275
1245- # Reshape to (B, T_lat*4 , H, W, C)
1276+ # Reshape to (B, T_lat * T_upsample , H, W, C)
12461277 b , t_lat , t_sub , h , w , c = decoded .shape
12471278 decoded = decoded .reshape (b , t_lat * t_sub , h , w , c )
12481279
0 commit comments