Commit a7b83a8

padding added

1 parent 46d392f · commit a7b83a8 · 1 file changed

src/maxdiffusion/models/wan/autoencoder_kl_wan.py: 83 additions & 22 deletions
@@ -57,6 +57,8 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
+
+    self.mesh = mesh
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
     self.stride = _canonicalize_tuple(stride, 3, "stride")
     padding_tuple = _canonicalize_tuple(padding, 3, "padding")  # (D, H, W) padding amounts
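Each module now keeps a reference to the device mesh so the forward pass can ask whether an 'fsdp' axis exists and how many devices sit on it. A minimal sketch of the two attributes being queried (the 1-axis mesh below is illustrative, not from this repo):

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Build a 1-axis mesh named 'fsdp' over all available devices (illustrative).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('fsdp',))

print(mesh.axis_names)     # ('fsdp',) -> what the `'fsdp' in self.mesh.axis_names` check sees
print(mesh.shape['fsdp'])  # axis size -> what the new code reads as fsdp_size
```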
@@ -114,11 +116,23 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
     # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
+      x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
-      x_padded = x
-    x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
-    out = self.conv(x_padded)
+      x_internal = x
+
+    h_dim_after_conv_padding = x_internal.shape[2]
+    pad_h_fsdp = 0
+    if self.mesh and 'fsdp' in self.mesh.axis_names:
+      fsdp_size = self.mesh.shape['fsdp']
+      if fsdp_size > 1:
+        if h_dim_after_conv_padding % fsdp_size != 0:
+          pad_h_fsdp = fsdp_size - (h_dim_after_conv_padding % fsdp_size)
+          h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+          x_internal = jnp.pad(x_internal, h_padding, mode="constant", constant_values=0.0)
+    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+      x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))
+
+    out = self.conv(x_internal)
     return out
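Before the sharding constraint is applied, the height axis is padded up to the next multiple of the fsdp axis size, since an array can only be split evenly along 'fsdp' when H divides cleanly. The same arithmetic as a standalone sketch (`pad_h_to_multiple` is a hypothetical helper, not in the file):

```python
import jax.numpy as jnp

def pad_h_to_multiple(x, h_axis, multiple):
  """Zero-pad `h_axis` up to the next multiple of `multiple` (no-op if already aligned)."""
  pad = (-x.shape[h_axis]) % multiple  # equals multiple - (h % multiple), or 0
  if pad == 0:
    return x
  widths = [(0, 0)] * x.ndim
  widths[h_axis] = (0, pad)
  return jnp.pad(x, widths, mode="constant", constant_values=0.0)

# H=15 with a 4-way fsdp axis: 15 % 4 != 0, so one extra row takes H to 16,
# which with_sharding_constraint can then split as 4 rows per device.
x = jnp.ones((1, 3, 15, 15, 8))  # (N, D, H, W, C), as in the conv above
assert pad_h_to_multiple(x, h_axis=2, multiple=4).shape[2] == 16
assert pad_h_to_multiple(x, h_axis=2, multiple=5).shape[2] == 15  # already divisible
```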
@@ -225,6 +239,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
+    self.mesh = mesh
     self.dim = dim
     self.mode = mode
     self.time_conv = nnx.data(None)
@@ -336,12 +351,46 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         x = x.reshape(b, t, h, w, 2, c)
         x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
         x = x.reshape(b, t * 2, h, w, c)
-    t = x.shape[1]
-    x = x.reshape(b * t, h, w, c)
-    x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None))
-    x = self.resample(x)
-    h_new, w_new, c_new = x.shape[1:]
-    x = x.reshape(b, t, h_new, w_new, c_new)
+    t = x.shape[1]
+    h = x.shape[2]
+
+    x_reshaped = x.reshape(b * t, h, w, c)
+
+    original_h = x_reshaped.shape[1]
+    pad_h_fsdp = 0
+    if self.mesh and 'fsdp' in self.mesh.axis_names:
+      fsdp_size = self.mesh.shape['fsdp']
+      if fsdp_size > 1:
+        if original_h % fsdp_size != 0:
+          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+          h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+          x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
+
+    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+      x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))
+
+    resampled_x = self.resample(x_reshaped)
+
+    if pad_h_fsdp > 0:
+      if "upsample" in self.mode:
+        scale_factor_h = 1.0
+        if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
+          scale_factor_h = self.resample.layers[0].scale_factor[0]
+        target_h = int(original_h * scale_factor_h)
+        resampled_x = resampled_x[:, :target_h, :, :]
+      elif "downsample" in self.mode:
+        stride_h = 1
+        if isinstance(self.resample, ZeroPaddedConv2D):
+          stride_h = self.resample.conv.strides[0]
+        elif isinstance(self.resample, nnx.Conv):
+          stride_h = self.resample.strides[0]
+
+        if stride_h > 1:
+          target_h = original_h // stride_h
+          resampled_x = resampled_x[:, :target_h, :, :]
+
+    h_new, w_new, c_new = resampled_x.shape[1:]
+    x = resampled_x.reshape(b, t, h_new, w_new, c_new)

     if self.mode == "downsample3d":
       if feat_cache is not None:
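After resampling, the rows that were added for fsdp alignment have to be cropped away so the block returns the shape the unpadded path would have produced: a 2x upsample turns original_h rows into 2 * original_h, while a stride-2 downsample yields original_h // 2. A small sketch of that crop arithmetic, where a stride-2 slice stands in for the strided ZeroPaddedConv2D (shapes and the 2x/stride-2 factors are illustrative only):

```python
import jax.numpy as jnp

original_h, fsdp_size = 15, 4
pad_h = (-original_h) % fsdp_size           # 1 padding row -> padded H = 16
x = jnp.ones((2, original_h + pad_h, 15, 8))

# Upsample path: nearest-neighbour 2x, then crop to 2 * original_h.
up = jnp.repeat(x, 2, axis=1)               # (2, 32, 15, 8)
up = up[:, : original_h * 2, :, :]          # (2, 30, 15, 8), as if H had never been padded

# Downsample path: stride-2 subsampling as a stand-in for the strided conv.
down = x[:, ::2, :, :]                      # (2, 8, 15, 8)
down = down[:, : original_h // 2, :, :]     # (2, 7, 15, 8) = target_h rows
```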
@@ -461,6 +510,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
+    self.mesh = mesh
     self.dim = dim
     self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
     self.to_qkv = nnx.Conv(
@@ -488,29 +538,40 @@ def __call__(self, x: jax.Array):

     identity = x
     batch_size, time, height, width, channels = x.shape
+    original_h = height
+
+    pad_h_fsdp = 0
+    if self.mesh and 'fsdp' in self.mesh.axis_names:
+      fsdp_size = self.mesh.shape['fsdp']
+      if fsdp_size > 1:
+        if original_h % fsdp_size != 0:
+          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
+          h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
+          x = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)

-    x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
+      x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))

-    x = x.reshape(batch_size * time, height, width, channels)
-    x = self.norm(x)
+    current_height = x.shape[2]
+    x_reshaped = x.reshape(batch_size * time, current_height, width, channels)
+    x_normed = self.norm(x_reshaped)

-    qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
-    # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+    qkv = self.to_qkv(x_normed)
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
     q, k, v = jnp.split(qkv, 3, axis=-2)
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
-    x = jax.nn.dot_product_attention(q, k, v)
-    x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
+    attn_out = jax.nn.dot_product_attention(q, k, v)
+    attn_out = jnp.squeeze(attn_out, 1).reshape(batch_size * time, current_height, width, channels)

-    # output projection
-    x = self.proj(x)
-    # Reshape back
-    x = x.reshape(batch_size, time, height, width, channels)
+    x_proj = self.proj(attn_out)
+    x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
+    if pad_h_fsdp > 0:
+      x_proj = x_proj[:, :, :original_h, :, :]

-    return x + identity
+    return x_proj + identity


 class WanMidBlock(nnx.Module):
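The attention block follows the same pad, shard, compute, crop recipe, and the final crop is what makes `x_proj + identity` shape-correct again, since `identity` still has the unpadded height. A toy round-trip check of that invariant (the `proj` lambda stands in for the real norm/attention/projection stack; shapes are illustrative):

```python
import jax.numpy as jnp

b, t, h, w, c, fsdp_size = 1, 2, 5, 5, 4, 4
x = jnp.ones((b, t, h, w, c))
identity = x

pad_h = (-h) % fsdp_size                       # 3 -> padded height 8
x_padded = jnp.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, 0), (0, 0)))

proj = lambda a: a * 2.0                       # stand-in for norm/attention/proj
x_proj = proj(x_padded)                        # (1, 2, 8, 5, 4)

if pad_h > 0:                                  # crop back before the residual add
  x_proj = x_proj[:, :, :h, :, :]
out = x_proj + identity                        # shapes agree again: (1, 2, 5, 5, 4)
assert out.shape == identity.shape
```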
