@@ -105,9 +105,10 @@ def initialize_cache(self, batch_size, height, width, dtype):
105105 if shard_axis is None and width % num_fsdp_devices == 0 :
106106 shard_width_axis = "fsdp"
107107
108+ # CRITICAL FIX: First axis is "data" (Batch), NOT None
108109 cache = jax .lax .with_sharding_constraint (
109110 cache ,
110- PartitionSpec (None , None , shard_axis , shard_width_axis , None )
111+ PartitionSpec ("data" , None , shard_axis , shard_width_axis , None )
111112 )
112113 return cache
113114
@@ -123,9 +124,10 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[j
123124 if shard_axis is None and width % num_fsdp_devices == 0 :
124125 shard_width_axis = "fsdp"
125126
127+ # CRITICAL FIX: First axis is "data" (Batch), NOT None
126128 x = jax .lax .with_sharding_constraint (
127129 x ,
128- PartitionSpec (None , None , shard_axis , shard_width_axis , None )
130+ PartitionSpec ("data" , None , shard_axis , shard_width_axis , None )
129131 )
130132
131133 current_padding = list (self ._causal_padding )
@@ -238,7 +240,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
238240 if self .mode == "upsample2d" :
239241 b , t , h , w , c = x .shape
240242 x = x .reshape (b * t , h , w , c )
241- x = self .resample (x ) # Sequential
243+ x = self .resample (x )
242244 h_new , w_new , c_new = x .shape [1 :]
243245 x = x .reshape (b , t , h_new , w_new , c_new )
244246
@@ -260,14 +262,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
260262 elif self .mode == "downsample2d" :
261263 b , t , h , w , c = x .shape
262264 x = x .reshape (b * t , h , w , c )
263- x , _ = self .resample (x , None ) # ZeroPaddedConv2D
265+ x , _ = self .resample (x , None )
264266 h_new , w_new , c_new = x .shape [1 :]
265267 x = x .reshape (b , t , h_new , w_new , c_new )
266268
267269 elif self .mode == "downsample3d" :
268270 b , t , h , w , c = x .shape
269271 x = x .reshape (b * t , h , w , c )
270- x , _ = self .resample (x , None ) # ZeroPaddedConv2D
272+ x , _ = self .resample (x , None )
271273 h_new , w_new , c_new = x .shape [1 :]
272274 x = x .reshape (b , t , h_new , w_new , c_new )
273275
@@ -561,7 +563,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
561563
562564
563565class AutoencoderKLWan (nnx .Module , FlaxModelMixin , ConfigMixin ):
564- def __init__ (self , rngs : nnx .Rngs , base_dim : int = 96 , z_dim : int = 16 , dim_mult : Tuple [int ] = [1 , 2 , 4 , 4 ], num_res_blocks : int = 2 , attn_scales : List [float ] = [], temperal_downsample : List [bool ] = [False , True , True ], dropout : float = 0.0 , latents_mean : List [float ] = [], latents_std : List [float ] = [], mesh : jax .sharding .Mesh = None , dtype : jnp .dtype = jnp .float32 , weights_dtype : jnp .dtype = jnp .float32 , precision : jax .lax .Precision = None ):
566+ def __init__ (self , rngs : nnx .Rngs , base_dim : int = 96 , z_dim : int = 16 , dim_mult : Tuple [int ] = [1 , 2 , 4 , 4 ], num_res_blocks : int = 2 , attn_scales : List [float ] = [], temperal_downsample : List [bool ] = [False , True , True ], dropout = 0.0 , latents_mean : List [float ] = [], latents_std : List [float ] = [], mesh : jax .sharding .Mesh = None , dtype : jnp .dtype = jnp .float32 , weights_dtype : jnp .dtype = jnp .float32 , precision : jax .lax .Precision = None ):
565567 self .z_dim = z_dim
566568 self .temperal_downsample = temperal_downsample
567569 self .temporal_upsample = temperal_downsample [::- 1 ]
@@ -611,9 +613,10 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
611613 def scan_fn (carry , input_slice ):
612614 # Expand Time dimension for Conv3d
613615 input_slice = jnp .expand_dims (input_slice , 1 )
616+ # OPTIMIZATION: Force bfloat16 accumulation within the scan
617+ # to save memory on the massive output buffer
614618 out_slice , new_carry = self .decoder (input_slice , carry )
615619 out_slice = out_slice .astype (jnp .bfloat16 )
616- # Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
617620 return new_carry , out_slice
618621
619622 final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )