Commit 34ebdbe

debug statements
1 parent 40d423d commit 34ebdbe

3 files changed

Lines changed: 107 additions & 27 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 99 additions & 19 deletions
@@ -23,7 +23,7 @@
 from ..modeling_flax_utils import FlaxModelMixin
 from ... import common_types
 from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
-
+import numpy as np
 BlockSizes = common_types.BlockSizes

 CACHE_T = 2
@@ -93,33 +93,51 @@ def __init__(
         rngs=rngs,
     )

-  def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array:
+  def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
+    print("wanCausalConv3d, x min: ", np.min(x))
+    print("wanCausalConv3d, x max: ", np.max(x))
     current_padding = list(self._causal_padding) # Mutable copy
     padding_needed = self._depth_padding_before

     if cache_x is not None and padding_needed > 0:
+      print("WanCausalConv3d, cache.shape: ", cache_x.shape)
+      print("wanCausalConv3d, cache_x min: ", np.min(cache_x))
+      print("wanCausalConv3d, cache_x max: ", np.max(cache_x))
       # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
       cache_len = cache_x.shape[1]
       x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D)

       padding_needed -= cache_len
       if padding_needed < 0:
+        print("wanCausanConv3d, padding_needed < 0")
         # Cache longer than needed padding, trim from start
         x = x[:, -padding_needed:, ...]
         current_padding[1] = (0, 0) # No explicit padding needed now
       else:
         # Update depth padding needed
+        print("wanCausanConv3d, padding_needed > 0")
         current_padding[1] = (padding_needed, 0)

     # Apply padding if any dimension requires it
     padding_to_apply = tuple(current_padding)
+    print("WanCausalConv3d, before padding x shape: ", x.shape)
     if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
+      print("WanCausalConv3d, applying padding")
       x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
     else:
+      print("WanCausalConv3d, NOT applying padding")
       x_padded = x

+    print("WanCausalConv3d, x shape: ", x_padded.shape)
+    print("wanCausalConv3d, x min: ", np.min(x_padded))
+    print("wanCausalConv3d, x max: ", np.max(x_padded))
+    # if idx == 12:
+    # breakpoint()
     out = self.conv(x_padded)
+    print("WanCausalConv3d, after conv, x shape: ", out.shape)
+    print("wanCausalConv3d, x min: ", np.min(out))
+    print("wanCausalConv3d, x max: ", np.max(out))
     return out

@@ -346,31 +364,48 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):

     if feat_cache is not None:
       idx = feat_idx[0]
+      print("Before conv1, idx: ", idx)
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
       if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-
-      x = self.conv1(x, feat_cache[idx])
+      x = self.conv1(x, feat_cache[idx], idx)
+      # if idx == 4:
+      # breakpoint()
       feat_cache[idx] = cache_x
       feat_idx[0] += 1
     else:
       x = self.conv1(x)

     x = self.norm2(x)
     x = self.nonlinearity(x)
+    idx = feat_idx[0]
+    # if idx == 4:
+    # breakpoint()

     if feat_cache is not None:
       idx = feat_idx[0]
+      print("Residual block, idx: ", idx)
+      # if idx == 14:
+      # breakpoint()
+      print("cache_x min: ", np.min(cache_x))
+      print("cache_x max: ", np.max(cache_x))
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
       if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      print("cache_x min: ", np.min(cache_x))
+      print("cache_x max: ", np.max(cache_x))
+      #breakpoint()
       x = self.conv2(x, feat_cache[idx])
       feat_cache[idx] = cache_x
       feat_idx[0] += 1
     else:
       x = self.conv2(x)
-
-    return x + h
+    print("before conv shortcut add: x min", np.min(x))
+    print("before conv shortcut add: x max", np.max(x))
+    x = x + h
+    print("after conv shortcut add: x min: ", np.min(x))
+    print("after conv shortcut add: x max: ", np.max(x))
+    return x


 class WanAttentionBlock(nnx.Module):
@@ -382,26 +417,51 @@ def __init__(self, dim: int, rngs: nnx.Rngs):
     self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs)

   def __call__(self, x: jax.Array):
-    batch_size, time, height, width, channels = x.shape
+
     identity = x
+    batch_size, time, height, width, channels = x.shape

     x = x.reshape(batch_size * time, height, width, channels)
     x = self.norm(x)

     qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3)
-
-    qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+    #breakpoint()
+    #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+    qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
-    q, k, v = jnp.split(qkv, 3, axis=-1)
-
-    x = jax.nn.dot_product_attention(q, k, v)
+    print("qkv min: ", np.min(qkv))
+    print("qkv max: ", np.max(qkv))
+    #q, k, v = jnp.split(qkv, 3, axis=-1)
+    q, k, v = jnp.split(qkv, 3, axis=-2)
+    print("q min: ", np.min(q))
+    print("q max: ", np.max(q))
+    print("k min: ", np.min(k))
+    print("k min: ", np.max(k))
+    print("v min: ", np.min(v))
+    print("v min: ", np.max(v))
+    #breakpoint()
+    q = jnp.transpose(q, (0, 1, 3, 2))
+    k = jnp.transpose(k, (0, 1, 3, 2))
+    v = jnp.transpose(v, (0, 1, 3, 2))
+    import torch
+    import torch.nn.functional as F
+    q = torch.tensor(np.array(q, dtype=np.float32))
+    k = torch.tensor(np.array(k, dtype=np.float32))
+    v = torch.tensor(np.array(v, dtype=np.float32))
+    #x = jax.nn.dot_product_attention(q, k, v)
+    x = F.scaled_dot_product_attention(q, k, v)
+    print("attn min: ", torch.min(x))
+    print("attn max: ", torch.max(x))
+    #breakpoint()
+    x = jnp.array(x.detach().numpy())
     x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)

     # output projection
     x = self.proj(x)
-
+    #breakpoint()
     # Reshape back
     x = x.reshape(batch_size, time, height, width, channels)
+    #breakpoint()

     return x + identity
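
Note on the WanAttentionBlock hunk above: to_qkv produces channel-last activations of shape (batch*time, height, width, 3*channels), so the q/k/v split has to run along the channel axis, which is what the revised reshape(..., -1, channels * 3) plus jnp.split(..., axis=-2) after the transpose now do (the labels on the k and v max prints still read "min"). The round-trip through torch.nn.functional.scaled_dot_product_attention looks like a temporary aid for comparing numerics against the PyTorch reference. Below is a minimal sketch of an equivalent pure-JAX path, assuming jax.nn.dot_product_attention and its default (batch, seq_len, num_heads, head_dim) layout; it is not part of this commit.

```python
# Sketch only, not part of commit 34ebdbe; assumes the channel-last layout
# documented above, i.e. a to_qkv output of shape (B*T, H, W, 3*C).
import jax
import jax.numpy as jnp

def single_head_spatial_attention(qkv: jax.Array) -> jax.Array:
  """Self-attention over the H*W positions of each frame, one head of width C."""
  bt, h, w, c3 = qkv.shape
  c = c3 // 3
  qkv = qkv.reshape(bt, h * w, c3)                 # (B*T, H*W, 3C)
  q, k, v = jnp.split(qkv, 3, axis=-1)             # each (B*T, H*W, C), split on channels
  # jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim);
  # insert a single-head axis.
  q, k, v = (t[:, :, None, :] for t in (q, k, v))
  out = jax.nn.dot_product_attention(q, k, v)      # (B*T, H*W, 1, C)
  return out.squeeze(2).reshape(bt, h, w, c)
```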

@@ -419,11 +479,20 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity
     self.resnets = resnets

   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+    print("WanMidblock...")
     x = self.resnets[0](x, feat_cache, feat_idx)
+    print("WanMidBlock resnets[0], x min: ", np.min(x))
+    print("WanMidBlock resnets[0], x max: ", np.max(x))
     for attn, resnet in zip(self.attentions, self.resnets[1:]):
+      print("WanMidBlock, for loop, attn len: ", len(self.attentions))
+      print("WanMidBlock, for loop, resnets len: ", len(self.resnets))
       if attn is not None:
         x = attn(x)
+        print("WanMidBlock attn[0], x min: ", np.min(x))
+        print("WanMidBlock attn[0], x max: ", np.max(x))
       x = resnet(x, feat_cache, feat_idx)
+      print("WanMidBlock resnets[i], x min: ", np.min(x))
+      print("WanMidBlock resnets[i], x max: ", np.max(x))
     return x

@@ -589,7 +658,7 @@ def __init__(
       self,
       rngs: nnx.Rngs,
       dim: int = 128,
-      z_dim: int = 128,
+      z_dim: int = 4,
       dim_mult: List[int] = [1, 2, 4, 4],
       num_res_blocks: int = 2,
       attn_scales=List[float],
@@ -662,7 +731,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):

     ## middle
     x = self.mid_block(x, feat_cache, feat_idx)
-
+    #breakpoint()
     ## upsamples
     for up_block in self.up_blocks:
       x = up_block(x, feat_cache, feat_idx)
@@ -810,7 +879,6 @@ def _encode(self, x: jax.Array):
     mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
     enc = jnp.concatenate([mu, logvar], axis=-1)
     self.clear_cache()
-    # return enc
     return enc

   def encode(
@@ -833,10 +901,22 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
         out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
       else:
         out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-
         out = jnp.concatenate([out, out_], axis=1)
-
-    out = jnp.clip(out, a_min=-1.0, a_max=1.0)
+    print("out_.shape: ", out_.shape)
+    print("out_ min: ", np.min(out_))
+    print("out_ max: ", np.max(out_))
+    print("out.shape: ", out.shape)
+    print("out min: ", np.min(out))
+    print("out max: ", np.max(out))
+    for i in range(len(self._feat_map)):
+      if isinstance(self._feat_map[i], jax.Array):
+        print("i: ", i)
+        print("min: ", np.min(self._feat_map[i]))
+        print("max: ", np.max(self._feat_map[i]))
+      else:
+        print(f"feat_map[{i}] : {self._feat_map[i]}")
+    # breakpoint()
+    out = jnp.clip(out, min=-1.0, max=1.0)
     self.clear_cache()
     if not return_dict:
       return (out,)
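
Two notes on the decoder hunks above, neither of which is part of this commit: the switch from a_min/a_max to min/max in jnp.clip matches the keyword names accepted by recent jax.numpy releases, and the added np.min/np.max prints only work while the decode path runs eagerly, since converting a traced array to NumPy under jax.jit raises an error. If the instrumentation needs to survive jitting, jax.debug.print can stage the same values out of the traced computation; a sketch with a hypothetical debug_stats helper follows.

```python
# Sketch only, not part of commit 34ebdbe.
import jax
import jax.numpy as jnp

def debug_stats(tag: str, x: jax.Array) -> None:
  """Log shape and min/max of an activation; works eagerly and under jax.jit."""
  # x.shape is static even for tracers, so it can go straight into the format string.
  jax.debug.print(tag + " shape=" + str(x.shape) + " min={mn} max={mx}",
                  mn=jnp.min(x), mx=jnp.max(x))

# e.g. in place of the paired print(...) calls above:
#   debug_stats("wanCausalConv3d x", x_padded)
#   debug_stats("WanMidBlock resnets[0] x", x)
```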

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
   with jax.default_device(device):
     if hf_download:
       ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors")
-      #breakpoint()
     max_logging.log(f"Load and port Wan 2.1 VAE on {device}")

     if ckpt_path is not None:

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 8 additions & 7 deletions
@@ -40,7 +40,6 @@

 CACHE_T = 2

-
 class TorchWanRMS_norm(nn.Module):
   r"""
   A custom RMS normalization layer.
@@ -92,16 +91,18 @@ def __init__(self, dim: int, mode: str) -> None:
           WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
       )
     elif mode == "upsample3d":
-      self.resample = nn.Sequential(
-          WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
-      )
-      self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+      # self.resample = nn.Sequential(
+      # WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+      # )
+      # self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+      raise Exception("downsample3d not supported")

     elif mode == "downsample2d":
       self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
     elif mode == "downsample3d":
-      self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
-      self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+      raise Exception("downsample3d not supported")
+      #self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+      #self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

     else:
       self.resample = nn.Identity()
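
In the PyTorch reference module above, both unsupported 3D resample branches now raise a bare Exception, and the message reads "downsample3d not supported" even on the upsample3d branch. A sketch of a more specific failure, using a hypothetical trimmed-down stand-in for the test helper (not part of this commit):

```python
# Sketch only, not part of commit 34ebdbe; ResampleStub is a hypothetical stand-in
# for the torch reference resample module defined in wan_vae_test.py.
import torch.nn as nn

class ResampleStub(nn.Module):
  def __init__(self, dim: int, mode: str) -> None:
    super().__init__()
    if mode in ("upsample3d", "downsample3d"):
      # Name the offending mode so a failing test points at the right branch.
      raise NotImplementedError(f"resample mode {mode!r} is not supported in this test helper")
    self.resample = nn.Identity()
```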
