 import jax
 import jax.numpy as jnp
 from flax import nnx
+from jax.sharding import PartitionSpec
 from ...configuration_utils import ConfigMixin
 from ..modeling_flax_utils import FlaxModelMixin, get_activation
 from ... import common_types
@@ -57,6 +58,7 @@ def __init__(
         self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
         self.stride = _canonicalize_tuple(stride, 3, "stride")
         padding_tuple = _canonicalize_tuple(padding, 3, "padding")
+        self.mesh = mesh
 
         self._causal_padding = (
             (0, 0),
@@ -90,9 +92,22 @@ def __init__(
         )
 
     def initialize_cache(self, batch_size, height, width, dtype):
-        return jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)
+        # Allocate the feature cache as zeros.
+        cache = jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)
+
+        # OPTIMIZATION: spatial partitioning at initialization.
+        # If we don't shard here, JAX allocates the full 2.64 GB cache on every chip, causing OOM.
+        if self.mesh is not None:
+            # Shard along height (axis 2). Axis layout: (batch, time, height, width, channels).
+            # The "fsdp" mesh axis carries the data-parallel / spatial split in this context.
+            cache = jax.lax.with_sharding_constraint(cache, PartitionSpec(None, None, "fsdp", None, None))
+        return cache
 
     def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[jax.Array, jax.Array]:
+        # OPTIMIZATION: apply the same spatial partitioning during execution.
+        if self.mesh is not None:
+            x = jax.lax.with_sharding_constraint(x, PartitionSpec(None, None, "fsdp", None, None))
+
         current_padding = list(self._causal_padding)
 
         if cache_x is not None:
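For reference, a minimal standalone sketch of the sharding pattern this hunk applies (the 1-D mesh construction, shapes, and dtype below are illustrative assumptions, not taken from this PR):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-dimensional mesh over all local devices; "fsdp" matches the axis
    # name used in the PartitionSpec above, but the mesh shape is assumed.
    mesh = Mesh(np.array(jax.devices()), ("fsdp",))
    spec = PartitionSpec(None, None, "fsdp", None, None)  # (B, T, H, W, C): split H

    @jax.jit
    def init_cache():
        # H (64 here) must be divisible by the device count on "fsdp".
        cache = jnp.zeros((1, 2, 64, 64, 384), dtype=jnp.bfloat16)
        # The constraint tells the compiler that each device materializes
        # only its height slice of the cache, never a full replica.
        return jax.lax.with_sharding_constraint(cache, NamedSharding(mesh, spec))

    print(init_cache().sharding)  # NamedSharding with the spec above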
@@ -174,7 +189,6 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
                 nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
             )
         elif mode == "upsample3d":
-            # 3D mode ALSO needs Sequential for the spatial part to match checkpoints
             self.resample = nnx.Sequential(
                 WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
                 nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
@@ -219,8 +233,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
 
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x = self.upsample(x)
-            x = self.conv(x)
+            x = self.resample(x)  # nnx.Sequential: WanUpsample then Conv
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
@@ -234,7 +247,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
         elif self.mode == "downsample3d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x, _ = self.resample(x, None)  # Fixed: use self.resample not self.downsample_conv
+            x, _ = self.resample(x, None)  # ZeroPaddedConv2D
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
@@ -532,8 +545,6 @@ def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temporal_upsample = temperal_downsample[::-1]
-
-        # MISSING attributes added back
         self.latents_mean = latents_mean
         self.latents_std = latents_std