Skip to content

Commit c033b9a

Browse files
committed
Fix
1 parent eaa96e5 commit c033b9a

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ def initialize_cache(self, batch_size, height, width, dtype):
105105
if shard_axis is None and width % num_fsdp_devices == 0:
106106
shard_width_axis = "fsdp"
107107

108+
# CRITICAL FIX: First axis is "data" (Batch), NOT None
108109
cache = jax.lax.with_sharding_constraint(
109110
cache,
110-
PartitionSpec(None, None, shard_axis, shard_width_axis, None)
111+
PartitionSpec("data", None, shard_axis, shard_width_axis, None)
111112
)
112113
return cache
113114

@@ -123,9 +124,10 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[j
123124
if shard_axis is None and width % num_fsdp_devices == 0:
124125
shard_width_axis = "fsdp"
125126

127+
# CRITICAL FIX: First axis is "data" (Batch), NOT None
126128
x = jax.lax.with_sharding_constraint(
127129
x,
128-
PartitionSpec(None, None, shard_axis, shard_width_axis, None)
130+
PartitionSpec("data", None, shard_axis, shard_width_axis, None)
129131
)
130132

131133
current_padding = list(self._causal_padding)
@@ -238,7 +240,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
238240
if self.mode == "upsample2d":
239241
b, t, h, w, c = x.shape
240242
x = x.reshape(b * t, h, w, c)
241-
x = self.resample(x) # Sequential
243+
x = self.resample(x)
242244
h_new, w_new, c_new = x.shape[1:]
243245
x = x.reshape(b, t, h_new, w_new, c_new)
244246

@@ -260,14 +262,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
260262
elif self.mode == "downsample2d":
261263
b, t, h, w, c = x.shape
262264
x = x.reshape(b * t, h, w, c)
263-
x, _ = self.resample(x, None) # ZeroPaddedConv2D
265+
x, _ = self.resample(x, None)
264266
h_new, w_new, c_new = x.shape[1:]
265267
x = x.reshape(b, t, h_new, w_new, c_new)
266268

267269
elif self.mode == "downsample3d":
268270
b, t, h, w, c = x.shape
269271
x = x.reshape(b * t, h, w, c)
270-
x, _ = self.resample(x, None) # ZeroPaddedConv2D
272+
x, _ = self.resample(x, None)
271273
h_new, w_new, c_new = x.shape[1:]
272274
x = x.reshape(b, t, h_new, w_new, c_new)
273275

@@ -561,7 +563,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
561563

562564

563565
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
564-
def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult: Tuple[int] = [1, 2, 4, 4], num_res_blocks: int = 2, attn_scales: List[float] = [], temperal_downsample: List[bool] = [False, True, True], dropout: float = 0.0, latents_mean: List[float] = [], latents_std: List[float] = [], mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None):
566+
def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult: Tuple[int] = [1, 2, 4, 4], num_res_blocks: int = 2, attn_scales: List[float] = [], temperal_downsample: List[bool] = [False, True, True], dropout=0.0, latents_mean: List[float] = [], latents_std: List[float] = [], mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None):
565567
self.z_dim = z_dim
566568
self.temperal_downsample = temperal_downsample
567569
self.temporal_upsample = temperal_downsample[::-1]
@@ -611,9 +613,10 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
611613
def scan_fn(carry, input_slice):
612614
# Expand Time dimension for Conv3d
613615
input_slice = jnp.expand_dims(input_slice, 1)
616+
# OPTIMIZATION: Force bfloat16 accumulation within the scan
617+
# to save memory on the massive output buffer
614618
out_slice, new_carry = self.decoder(input_slice, carry)
615619
out_slice = out_slice.astype(jnp.bfloat16)
616-
# Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
617620
return new_carry, out_slice
618621

619622
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)

0 commit comments

Comments (0)