Commit 10aea6a

spatial sharding added
1 parent 712cc4a commit 10aea6a

1 file changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py
Lines changed: 65 additions & 150 deletions
@@ -57,8 +57,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-
-    self.mesh = mesh
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
     self.stride = _canonicalize_tuple(stride, 3, "stride")
     padding_tuple = _canonicalize_tuple(padding, 3, "padding")  # (D, H, W) padding amounts
@@ -95,37 +93,35 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-    current_padding = list(self._causal_padding)
+    current_padding = list(self._causal_padding)  # Mutable copy
     padding_needed = self._depth_padding_before
 
     if cache_x is not None and padding_needed > 0:
+      # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
       cache_len = cache_x.shape[1]
-      x = jnp.concatenate([cache_x, x], axis=1)
+      x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
+
       padding_needed -= cache_len
       if padding_needed < 0:
+        # Cache longer than needed padding, trim from start
         x = x[:, -padding_needed:, ...]
-        current_padding[1] = (0, 0)
+        current_padding[1] = (0, 0)  # No explicit padding needed now
       else:
+        # Update depth padding needed
         current_padding[1] = (padding_needed, 0)
 
+    # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-      x_internal = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
+      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
-      x_internal = x
-
-    # REMOVED FSDP PADDING LOGIC FROM HERE
-    # Sharding constraints are fine, but JAX will error if not divisible.
-    # This will be handled in the calling block.
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_internal = jax.lax.with_sharding_constraint(x_internal, P(None, None, 'fsdp', None, None))
-
-    out = self.conv(x_internal)
+      x_padded = x
+    x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
+    out = self.conv(x_padded)
     return out
 
 
-
 class WanRMS_norm(nnx.Module):
 
   def __init__(
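For reference, a minimal sketch of the causal depth-padding logic this hunk touches: frames cached from the previous chunk stand in for part of the causal padding, and only the remainder is zero-padded at the front of the depth axis. The padding amount and all shapes below are illustrative assumptions, not values from the file.

    import jax.numpy as jnp

    depth_padding_before = 2                # causal padding the conv needs (assumed)
    x = jnp.ones((1, 4, 8, 8, 16))          # (N, D, H, W, C): current chunk of 4 frames (assumed)
    cache_x = jnp.zeros((1, 1, 8, 8, 16))   # 1 cached frame from the previous chunk (assumed)

    padding_needed = depth_padding_before
    if cache_x is not None and padding_needed > 0:
      x = jnp.concatenate([cache_x, x], axis=1)  # prepend the cache along depth (D)
      padding_needed -= cache_x.shape[1]         # the cache covers that much of the padding

    # Whatever the cache did not cover is still zero-padded at the front of D.
    x = jnp.pad(x, ((0, 0), (max(padding_needed, 0), 0), (0, 0), (0, 0), (0, 0)))
    print(x.shape)  # (1, 6, 8, 8, 16): 4 new frames + 1 cached frame + 1 zero-padded frame

When the cache is longer than the remaining padding, the hunk instead trims the excess from the front, which is what the x[:, -padding_needed:, ...] branch handles.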
@@ -229,7 +225,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.dim = dim
     self.mode = mode
     self.time_conv = nnx.data(None)
@@ -318,7 +313,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
     # Input x: (N, D, H, W, C), assume C = self.dim
     b, t, h, w, c = x.shape
     assert c == self.dim
-    original_h = h
 
     if self.mode == "upsample3d":
       if feat_cache is not None:
@@ -342,57 +336,12 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         x = x.reshape(b, t, h, w, 2, c)
         x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
         x = x.reshape(b, t * 2, h, w, c)
-    # Update t and h as they might have changed in upsample3d
-    t = x.shape[1]
-    h = x.shape[2]
-    # original_h remains the height *before* this block's operations
-
-    x_reshaped = x.reshape(b * t, h, w, c)
-    current_h = x_reshaped.shape[1]
-
-    # --- FSDP Spatial Padding ---
-    pad_h_fsdp = 0
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if current_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (current_h % fsdp_size)
-          h_padding = ((0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x_reshaped = jnp.pad(x_reshaped, h_padding, mode="constant", constant_values=0.0)
-    # --- End FSDP Spatial Padding ---
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_reshaped = jax.lax.with_sharding_constraint(x_reshaped, P(None, 'fsdp', None, None))
-
-    resampled_x = self.resample(x_reshaped)
-
-    # --- FSDP Spatial Slicing ---
-    if pad_h_fsdp > 0:
-      if "upsample" in self.mode:
-        scale_factor_h = 1.0
-        if isinstance(self.resample, nnx.Sequential) and isinstance(self.resample.layers[0], WanUpsample):
-          scale_factor_h = self.resample.layers[0].scale_factor[0]
-        target_h = int(current_h * scale_factor_h)
-        resampled_x = resampled_x[:, :target_h, :, :]
-      elif "downsample" in self.mode:
-        stride_h = 1
-        if isinstance(self.resample, ZeroPaddedConv2D):
-          stride_h = self.resample.conv.strides[0]
-        elif isinstance(self.resample, nnx.Conv):
-          stride_h = self.resample.strides[0]
-
-        if stride_h > 1:
-          # kernel_size and padding affect output size,
-          # For "VALID" in ZeroPaddedConv2D (which has no other padding), out = (in - kernel + stride) // stride
-          # Since we added padding for FSDP, we want the size as if no FSDP padding was added.
-          k_h = self.resample.conv.kernel_size[0]
-          target_h = (current_h - k_h + stride_h) // stride_h
-          resampled_x = resampled_x[:, :target_h, :, :]
-        # If stride_h is 1, no slicing needed as the size doesn't shrink.
-    # --- End FSDP Spatial Slicing ---
-
-    h_new, w_new, c_new = resampled_x.shape[1:]
-    x = resampled_x.reshape(b, t, h_new, w_new, c_new)
+    t = x.shape[1]
+    x = x.reshape(b * t, h, w, c)
+    x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None))
+    x = self.resample(x)
+    h_new, w_new, c_new = x.shape[1:]
+    x = x.reshape(b, t, h_new, w_new, c_new)
 
     if self.mode == "downsample3d":
       if feat_cache is not None:
@@ -405,8 +354,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
         feat_cache[idx] = cache_x
         feat_idx[0] += 1
-        return x
 
+    return x
 
 
 class WanResidualBlock(nnx.Module):
@@ -423,7 +372,6 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.nonlinearity = get_activation(non_linearity)
 
     # layers
@@ -467,54 +415,39 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    original_shape = x.shape
-    original_h = original_shape[2]
-    original_w = original_shape[3]
-    pad_h_fsdp = 0
-    pad_w_fsdp = 0
-    x_padded = x
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if original_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
-        # Assuming width is not sharded on fsdp, add if needed
-        # if original_w % fsdp_size != 0:
-        #   pad_w_fsdp = fsdp_size - (original_w % fsdp_size)
-
-    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
-      h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, pad_w_fsdp), (0, 0))
-      x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
-
-    h = self.conv_shortcut(x_padded)
+    # Apply shortcut connection
+    h = self.conv_shortcut(x)
 
-    temp_x = self.norm1(x_padded)
-    temp_x = self.nonlinearity(temp_x)
-    temp_x = self.conv1(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
-    temp_x = self.norm2(temp_x)
-    temp_x = self.nonlinearity(temp_x)
-    temp_x = self.conv2(temp_x, cache_x=feat_cache[idx] if feat_cache else None)
-
-    # --- Crop temp_x to match h's spatial dimensions ---
-    h_height, h_width = h.shape[2], h.shape[3]
-    x_height, x_width = temp_x.shape[2], temp_x.shape[3]
-
-    if x_height > h_height:
-      ch = (x_height - h_height) // 2
-      temp_x = temp_x[:, :, ch:ch + h_height, :, :]
-    if x_width > h_width:
-      cw = (x_width - h_width) // 2
-      temp_x = temp_x[:, :, :, cw:cw + h_width, :]
-    # --- End Crop ---
-
-    res_x = temp_x + h
+    x = self.norm1(x)
+    x = self.nonlinearity(x)
 
-    if pad_h_fsdp > 0 or pad_w_fsdp > 0:
-      res_x = res_x[:, :, :original_h, :original_w, :]
-    return res_x
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv1(x, feat_cache[idx], idx)
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv1(x)
 
+    x = self.norm2(x)
+    x = self.nonlinearity(x)
+    idx = feat_idx[0]
 
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv2(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv2(x)
+    x = x + h
+    return x
 
 
 class WanAttentionBlock(nnx.Module):
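As an aside, the rewritten residual block refreshes a per-layer feature cache before each causal conv call: it keeps the last CACHE_T frames and, when fewer than two new frames are available, borrows the final frame of the previous cache. A minimal sketch of that refresh step, with CACHE_T and all shapes assumed for illustration:

    import jax.numpy as jnp

    CACHE_T = 2                                   # assumed value of the module-level constant
    x = jnp.ones((1, 1, 4, 4, 8))                 # (N, D, H, W, C): only 1 new frame this step (assumed)
    prev_cache = jnp.zeros((1, 2, 4, 4, 8))       # cache left over from the previous step (assumed)

    cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])  # keep at most the last CACHE_T frames
    if cache_x.shape[1] < 2 and prev_cache is not None:
      # Too few new frames: borrow the last frame of the previous cache.
      cache_x = jnp.concatenate(
          [jnp.expand_dims(prev_cache[:, -1, :, :, :], axis=1), cache_x], axis=1
      )
    print(cache_x.shape)  # (1, 2, 4, 4, 8): ready to store back into feat_cache[idx]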
@@ -528,7 +461,6 @@ def __init__(
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
   ):
-    self.mesh = mesh
     self.dim = dim
     self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
     self.to_qkv = nnx.Conv(
@@ -553,49 +485,32 @@
     )
 
   def __call__(self, x: jax.Array):
+
     identity = x
     batch_size, time, height, width, channels = x.shape
-    original_h = height
-
-    # --- FSDP Spatial Padding ---
-    pad_h_fsdp = 0
-    x_padded = x
-    if self.mesh and 'fsdp' in self.mesh.axis_names:
-      fsdp_size = self.mesh.shape['fsdp']
-      if fsdp_size > 1:
-        if original_h % fsdp_size != 0:
-          pad_h_fsdp = fsdp_size - (original_h % fsdp_size)
-          h_padding = ((0, 0), (0, 0), (0, pad_h_fsdp), (0, 0), (0, 0))
-          x_padded = jnp.pad(x, h_padding, mode="constant", constant_values=0.0)
-    # --- End FSDP Spatial Padding ---
-
-    if self.mesh and 'fsdp' in self.mesh.axis_names and self.mesh.shape['fsdp'] > 1:
-      x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
-
-    current_height = x_padded.shape[2]
-    x_reshaped = x_padded.reshape(batch_size * time, current_height, width, channels)
-    x_normed = self.norm(x_reshaped)
-
-    qkv = self.to_qkv(x_normed)
+
+    x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
+
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
+    # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
     q, k, v = jnp.split(qkv, 3, axis=-2)
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
-    attn_out = jax.nn.dot_product_attention(q, k, v)
-    attn_out = jnp.squeeze(attn_out, 1).reshape(batch_size * time, current_height, width, channels)
-
-    x_proj = self.proj(attn_out)
-    x_proj = x_proj.reshape(batch_size, time, current_height, width, channels)
-
-    # --- FSDP Spatial Slicing ---
-    if pad_h_fsdp > 0:
-      x_proj = x_proj[:, :, :original_h, :, :]
-    # --- End FSDP Spatial Slicing ---
+    x = jax.nn.dot_product_attention(q, k, v)
+    x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
 
-    return x_proj + identity
+    # output projection
+    x = self.proj(x)
+    # Reshape back
+    x = x.reshape(batch_size, time, height, width, channels)
 
+    return x + identity
 
 
 class WanMidBlock(nnx.Module):
@@ -1234,4 +1149,4 @@ def decode(
     decoded = self._decode(z, feat_cache).sample
     if not return_dict:
       return (decoded,)
-    return FlaxDecoderOutput(sample=decoded)
+    return FlaxDecoderOutput(sample=decoded)
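Taken together, the commit drops the manual pad-height-to-a-multiple-of-fsdp / slice-back bookkeeping and instead marks the activations with sharding constraints that split the height dimension across the 'fsdp' mesh axis, leaving data placement to the compiler. A minimal, self-contained sketch of that pattern (the device mesh setup, shapes, and the explicit NamedSharding are assumptions; the file itself passes a bare PartitionSpec and relies on the mesh the model is set up with):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1-D device mesh whose single axis is named 'fsdp', matching the axis name in the diff.
    mesh = Mesh(jax.devices(), axis_names=("fsdp",))

    @jax.jit
    def shard_height(x):
      # Constrain x so that H (axis 2 of an N, D, H, W, C tensor) is split across 'fsdp'.
      return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, None, "fsdp", None, None)))

    x = jnp.ones((1, 4, 8, 8, 16))   # toy (N, D, H, W, C) activation; H chosen to divide evenly here
    print(shard_height(x).sharding)  # shows how the output is laid out across the mesh

A comment removed earlier in the diff notes that sizes which do not divide evenly along the constrained axis were a concern, which is why the earlier code padded H before constraining it; the toy H above is simply chosen to avoid that case.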
