
Commit 712cc4a

padding fix
1 parent a7b83a8 commit 712cc4a

1 file changed

Lines changed: 81 additions & 57 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -95,47 +95,37 @@ def __init__(
     )

   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-    current_padding = list(self._causal_padding)  # Mutable copy
+    current_padding = list(self._causal_padding)
     padding_needed = self._depth_padding_before

     if cache_x is not None and padding_needed > 0:
-      # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
       cache_len = cache_x.shape[1]
-      x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
-
+      x = jnp.concatenate([cache_x, x], axis=1)
       padding_needed -= cache_len
       if padding_needed < 0:
-        # Cache longer than needed padding, trim from start
         x = x[:, -padding_needed:, ...]
-        current_padding[1] = (0, 0)  # No explicit padding needed now
+        current_padding[1] = (0, 0)
       else:
-        # Update depth padding needed
         current_padding[1] = (padding_needed, 0)

-    # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
       x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
       x_internal = x

-    h_dim_after_conv_padding = x_internal.shape[2]
-    pad_h_fsdp = 0
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if h_dim_after_conv_padding % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (h_dim_after_conv_padding % fsdp_size)
-          h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x_internal = jnp.pad(x_internal, h_padding, mode="constant", constant_values=0.0)
+    # REMOVED FSDP PADDING LOGIC FROM HERE
+    # Sharding constraints are fine, but JAX will error if not divisible.
+    # This will be handled in the calling block.
     if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
       x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))

     out = self.conv(x_internal)
     return out


+
 class WanRMS_norm(nnx.Module):

   def __init__(
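
For readers of the hunk above: the cached frames from the previous chunk stand in for part of the causal zero padding along depth. A minimal standalone sketch of that bookkeeping, not part of the commit; causal_depth_pad is a hypothetical name and the 5-D (N, D, H, W, C) layout is assumed from the surrounding code:

import jax.numpy as jnp

def causal_depth_pad(x, pad_before, cache=None):
  # Cached frames reduce the amount of explicit zero padding still needed on the left.
  if cache is not None and pad_before > 0:
    x = jnp.concatenate([cache, x], axis=1)
    pad_before -= cache.shape[1]
    if pad_before < 0:
      # Cache was longer than the required padding; drop the excess leading frames.
      x = x[:, -pad_before:, ...]
      pad_before = 0
  if pad_before > 0:
    x = jnp.pad(x, ((0, 0), (pad_before, 0), (0, 0), (0, 0), (0, 0)))
  return x

# A kernel needing 2 frames of left context, with a 1-frame cache: one cached frame
# plus one zero-padded frame are prepended.
x = jnp.ones((1, 4, 8, 8, 3))
cache = jnp.zeros((1, 1, 8, 8, 3))
assert causal_depth_pad(x, 2, cache).shape[1] == 6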
@@ -328,6 +318,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
     # Input x: (N, D, H, W, C), assume C = self.dim
     b, t, h, w, c = x.shape
     assert c == self.dim
+    original_h = h

     if self.mode == "upsample3d":
       if feat_cache is not None:
@@ -351,32 +342,37 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
           x = x.reshape(b, t, h, w, 2, c)
           x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
           x = x.reshape(b, t * 2, h, w, c)
+    # Update t and h as they might have changed in upsample3d
     t = x.shape[1]
     h = x.shape[2]
+    # original_h remains the height *before* this block's operations

     x_reshaped = x.reshape(b * t, h, w, c)
+    current_h = x_reshaped.shape[1]

-    original_h = x_reshaped.shape[1]
+    # --- FSDP Spatial Padding ---
     pad_h_fsdp = 0
     if self.mesh and 'fsdp' in self.mesh.axis_names:
       fsdp_size = self.mesh.shape['fsdp']
       if fsdp_size > 1:
-        if original_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+        if current_h % fsdp_size != 0:
+          pad_h_fsdp = fsdp_size - (current_h % fsdp_size)
           h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
           x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
+    # --- End FSDP Spatial Padding ---

     if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
       x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))

     resampled_x = self.resample(x_reshaped)

+    # --- FSDP Spatial Slicing ---
     if pad_h_fsdp > 0:
       if "upsample" in self.mode:
         scale_factor_h = 1.0
         if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
           scale_factor_h = self.resample.layers[0].scale_factor[0]
-        target_h = int(original_h * scale_factor_h)
+        target_h = int(current_h * scale_factor_h)
         resampled_x = resampled_x[:, :target_h, :, :]
       elif "downsample" in self.mode:
         stride_h = 1
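
A minimal standalone sketch of the pad-to-multiple step used in this hunk, assuming the (B*T, H, W, C) layout of x_reshaped; pad_h_to_multiple is a hypothetical helper name, not part of the commit:

import jax.numpy as jnp

def pad_h_to_multiple(x, multiple):
  # Zero-pad the H axis (axis 1) up to a multiple of the 'fsdp' axis size,
  # returning the padded array and the amount of padding added.
  h = x.shape[1]
  pad = 0 if h % multiple == 0 else multiple - (h % multiple)
  if pad > 0:
    x = jnp.pad(x, ((0, 0), (0, pad), (0, 0), (0, 0)), mode="constant", constant_values=0.0)
  return x, pad

x = jnp.ones((2, 30, 16, 8))            # H=30 with an fsdp axis of size 4
x_padded, pad = pad_h_to_multiple(x, 4)
assert x_padded.shape[1] == 32 and pad == 2
# After resampling, the hunk slices back to the height the unpadded input would have produced.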
@@ -386,8 +382,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
           stride_h = self.resample.strides[0]

         if stride_h > 1:
-          target_h = original_h // stride_h
+          # kernel_size and padding affect output size,
+          # For "VALID" in ZeroPaddedConv2D (which has no other padding), out = (in - kernel + stride) // stride
+          # Since we added padding for FSDP, we want the size as if no FSDP padding was added.
+          k_h = self.resample.conv.kernel_size[0]
+          target_h = (current_h - k_h + stride_h) // stride_h
           resampled_x = resampled_x[:, :target_h, :, :]
+        # If stride_h is 1, no slicing needed as the size doesn't shrink.
+    # --- End FSDP Spatial Slicing ---

     h_new, w_new, c_new = resampled_x.shape[1:]
     x = resampled_x.reshape(b, t, h_new, w_new, c_new)
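
To make the output-size comment concrete, a small check of the VALID-convolution formula the slice targets (valid_out_size is an illustrative helper, not part of the commit):

def valid_out_size(in_size, kernel, stride):
  # Output length of an unpadded ("VALID") convolution along one axis.
  return (in_size - kernel + stride) // stride

# Worked example: H=30, kernel=3, stride=2 gives (30 - 3 + 2) // 2 = 14,
# which matches floor((30 - 3) / 2) + 1 = 14.
assert valid_out_size(30, 3, 2) == 14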
@@ -403,10 +405,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
         feat_cache[idx] = cache_x
         feat_idx[0] += 1
-
     return x


+
 class WanResidualBlock(nnx.Module):

   def __init__(
@@ -421,6 +423,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
+    self.mesh = mesh
     self.nonlinearity = get_activation(non_linearity)

     # layers
@@ -464,39 +467,54 @@ def __init__(
     )

   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    # Apply shortcut connection
-    h = self.conv_shortcut(x)
+    original_shape = x.shape
+    original_h = original_shape[2]
+    original_w = original_shape[3]
+    pad_h_fsdp = 0
+    pad_w_fsdp = 0
+    x_padded = x

-    x = self.norm1(x)
-    x = self.nonlinearity(x)
+    if self.mesh and 'fsdp' in self.mesh.axis_names:
+      fsdp_size = self.mesh.shape['fsdp']
+      if fsdp_size > 1:
+        if original_h % fsdp_size != 0:
+          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+        # Assuming width is not sharded on fsdp, add if needed
+        # if original_w % fsdp_size != 0:
+        #   pad_w_fsdp = fsdp_size - (original_w % fsdp_size)

-    if feat_cache is not None:
-      idx = feat_idx[0]
-      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
-        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-      x = self.conv1(x, feat_cache[idx], idx)
-      feat_cache[idx] = cache_x
-      feat_idx[0] += 1
-    else:
-      x = self.conv1(x)
+    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
+      h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, pad_w_fsdp), (0, 0))
+      x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+
+    h = self.conv_shortcut(x_padded)
+
+    temp_x = self.norm1(x_padded)
+    temp_x = self.nonlinearity(temp_x)
+    temp_x = self.conv1(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
+    temp_x = self.norm2(temp_x)
+    temp_x = self.nonlinearity(temp_x)
+    temp_x = self.conv2(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
+
+    # --- Crop temp_x to match h's spatial dimensions ---
+    h_height, h_width = h.shape[2], h.shape[3]
+    x_height, x_width = temp_x.shape[2], temp_x.shape[3]
+
+    if x_height > h_height:
+      ch = (x_height - h_height) // 2
+      temp_x = temp_x[:, :, ch:ch + h_height, :, :]
+    if x_width > h_width:
+      cw = (x_width - h_width) // 2
+      temp_x = temp_x[:, :, :, cw:cw + h_width, :]
+    # --- End Crop ---
+
+    res_x = temp_x + h
+
+    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
+      res_x = res_x[:, :, :original_h, :original_w, :]
+    return res_x

-    x = self.norm2(x)
-    x = self.nonlinearity(x)
-    idx = feat_idx[0]

-    if feat_cache is not None:
-      idx = feat_idx[0]
-      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
-        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-      x = self.conv2(x, feat_cache[idx])
-      feat_cache[idx] = cache_x
-      feat_idx[0] += 1
-    else:
-      x = self.conv2(x)
-    x = x + h
-    return x


 class WanAttentionBlock(nnx.Module):
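
The crop block above can be read as a center crop of H and W down to the shortcut's shape. A standalone sketch under that reading (center_crop_hw is a hypothetical name, not part of the commit):

import jax.numpy as jnp

def center_crop_hw(x, target_h, target_w):
  # Center-crop axes 2 (H) and 3 (W) of a (B, T, H, W, C) array to the target sizes.
  _, _, h, w, _ = x.shape
  if h > target_h:
    top = (h - target_h) // 2
    x = x[:, :, top:top + target_h, :, :]
  if w > target_w:
    left = (w - target_w) // 2
    x = x[:, :, :, left:left + target_w, :]
  return x

assert center_crop_hw(jnp.ones((1, 2, 10, 8, 3)), 8, 8).shape == (1, 2, 8, 8, 3)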
@@ -535,25 +553,27 @@ def __init__(
     )

   def __call__(self, x: jax.Array):
-
     identity = x
     batch_size, time, height, width, channels = x.shape
     original_h = height

+    # --- FSDP Spatial Padding ---
     pad_h_fsdp = 0
+    x_padded = x
     if self.mesh and 'fsdp' in self.mesh.axis_names:
       fsdp_size = self.mesh.shape['fsdp']
       if fsdp_size > 1:
         if original_h % fsdp_size != 0:
           pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
           h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+          x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
+    # --- End FSDP Spatial Padding ---

     if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+      x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))

-    current_height = x.shape[2]
-    x_reshaped = x.reshape(batch_size * time, current_height, width, channels)
+    current_height = x_padded.shape[2]
+    x_reshaped = x_padded.reshape(batch_size * time, current_height, width, channels)
     x_normed = self.norm(x_reshaped)

     qkv = self.to_qkv(x_normed)
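
For context on the sharding constraint kept in this hunk, a minimal sketch of jax.lax.with_sharding_constraint over a hypothetical single-axis 'fsdp' mesh; the mesh construction and function name are assumptions for illustration, and the H axis must divide evenly across the axis, which is what the padding above guarantees:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-axis mesh named 'fsdp' over the local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

@jax.jit
def constrain_h(x):
  # Ask the compiler to shard axis 2 (H) of an (N, D, H, W, C) array across 'fsdp'.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, None, "fsdp", None, None)))

y = constrain_h(jnp.ones((1, 2, 8 * len(jax.devices()), 4, 3)))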
@@ -568,12 +588,16 @@ def __call__(self, x: jax.Array):

     x_proj = self.proj(attn_out)
     x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
+
+    # --- FSDP Spatial Slicing ---
     if pad_h_fsdp > 0:
       x_proj = x_proj[:, :, :original_h, :, :]
+    # --- End FSDP Spatial Slicing ---

     return x_proj + identity


+
 class WanMidBlock(nnx.Module):

   def __init__(

0 commit comments

Comments
 (0)