Commit 66146b9

fixes jittery decoder frames in vae.

Parent: 04f4909

1 file changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py
14 additions, 72 deletions

@@ -94,50 +94,31 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
-    print("wanCausalConv3d, x min: ", np.min(x))
-    print("wanCausalConv3d, x max: ", np.max(x))
     current_padding = list(self._causal_padding)  # Mutable copy
     padding_needed = self._depth_padding_before
 
     if cache_x is not None and padding_needed > 0:
-      print("WanCausalConv3d, cache.shape: ", cache_x.shape)
-      print("wanCausalConv3d, cache_x min: ", np.min(cache_x))
-      print("wanCausalConv3d, cache_x max: ", np.max(cache_x))
       # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
       cache_len = cache_x.shape[1]
       x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
 
      padding_needed -= cache_len
      if padding_needed < 0:
-        print("wanCausanConv3d, padding_needed < 0")
        # Cache longer than needed padding, trim from start
        x = x[:, -padding_needed:, ...]
        current_padding[1] = (0, 0)  # No explicit padding needed now
      else:
        # Update depth padding needed
-        print("wanCausanConv3d, padding_needed > 0")
        current_padding[1] = (padding_needed, 0)
 
    # Apply padding if any dimension requires it
    padding_to_apply = tuple(current_padding)
-    print("WanCausalConv3d, before padding x shape: ", x.shape)
    if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
-      print("WanCausalConv3d, applying padding")
      x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
    else:
-      print("WanCausalConv3d, NOT applying padding")
      x_padded = x
-
-    print("WanCausalConv3d, x shape: ", x_padded.shape)
-    print("wanCausalConv3d, x min: ", np.min(x_padded))
-    print("wanCausalConv3d, x max: ", np.max(x_padded))
-    # if idx == 12:
-    #   breakpoint()
    out = self.conv(x_padded)
-    print("WanCausalConv3d, after conv, x shape: ", out.shape)
-    print("wanCausalConv3d, x min: ", np.min(out))
-    print("wanCausalConv3d, x max: ", np.max(out))
    return out
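
Note on the surviving logic above: the convolution is causal along the depth (frame) axis, so zero padding goes only in front of the first frame, and cached frames from the previous chunk can stand in for part of that padding. A minimal standalone sketch of the idea (shapes and names here are illustrative, not the file's API):

import jax.numpy as jnp

# (B, D, H, W, C) layout, matching the axis=1 depth concatenation above.
x = jnp.ones((1, 5, 8, 8, 16))
k_depth = 3  # temporal kernel size (illustrative)

# Causal padding: k_depth - 1 zero frames in front, none behind, so each
# output frame depends only on itself and earlier frames.
causal_pad = ((0, 0), (k_depth - 1, 0), (0, 0), (0, 0), (0, 0))
x_causal = jnp.pad(x, causal_pad, mode="constant", constant_values=0.0)
print(x_causal.shape)  # (1, 7, 8, 8, 16)

# With a cache of previous frames, real history replaces part of the zeros:
cache = jnp.zeros((1, 2, 8, 8, 16))
x_cached = jnp.concatenate([cache, x], axis=1)  # less explicit padding needed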

@@ -300,8 +281,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
       if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
         # cache last frame of last two chunk
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-      if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
-        cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], dim=1)
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+        cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
       if feat_cache[idx] == "Rep":
         x = self.time_conv(x)
       else:
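
This hunk fixes two bugs in the duplicated branch: the condition now checks for the "Rep" sentinel (== "Rep") instead of repeating the previous branch's != "Rep", and the PyTorch-style dim=1 keyword is replaced with JAX's axis=1. A quick sketch of why the removed line could never have run:

import jax.numpy as jnp

cache_x = jnp.ones((1, 1, 4, 4, 8))  # one cached frame, (B, T, H, W, C)

# jnp.concatenate only accepts `axis=`; the removed `dim=1` (a PyTorch-ism)
# raises a TypeError, so that branch could not have executed successfully.
padded = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
print(padded.shape)  # (1, 2, 4, 4, 8)
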
@@ -364,13 +345,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
 
     if feat_cache is not None:
       idx = feat_idx[0]
-      print("Before conv1, idx: ", idx)
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
       if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
       x = self.conv1(x, feat_cache[idx], idx)
-      # if idx == 4:
-      #   breakpoint()
       feat_cache[idx] = cache_x
       feat_idx[0] += 1
     else:

@@ -379,32 +357,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
     x = self.norm2(x)
     x = self.nonlinearity(x)
     idx = feat_idx[0]
-    # if idx == 4:
-    #   breakpoint()
 
     if feat_cache is not None:
       idx = feat_idx[0]
-      print("Residual block, idx: ", idx)
-      # if idx == 14:
-      #   breakpoint()
-      print("cache_x min: ", np.min(cache_x))
-      print("cache_x max: ", np.max(cache_x))
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
       if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-      print("cache_x min: ", np.min(cache_x))
-      print("cache_x max: ", np.max(cache_x))
-      #breakpoint()
       x = self.conv2(x, feat_cache[idx])
       feat_cache[idx] = cache_x
       feat_idx[0] += 1
     else:
       x = self.conv2(x)
-    print("before conv shortcut add: x min", np.min(x))
-    print("before conv shortcut add: x max", np.max(x))
     x = x + h
-    print("after conv shortcut add: x min: ", np.min(x))
-    print("after conv shortcut add: x max: ", np.max(x))
     return x
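
For context, the residual block keeps a rolling cache of the last CACHE_T frames of each activation; when the current chunk yields only one frame, the last frame of the previous chunk's cache is prepended so the causal conv still sees two frames of history. A standalone sketch of that cache update (values are illustrative):

import jax.numpy as jnp

CACHE_T = 2
x = jnp.ones((1, 1, 4, 4, 8))            # current chunk: one frame
prev_cache = jnp.zeros((1, 2, 4, 4, 8))  # cache from the previous chunk

cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])  # only one frame available here
if cache_x.shape[1] < 2:
  # Prepend the previous cache's last frame to restore two frames of history.
  last = jnp.expand_dims(prev_cache[:, -1, :, :, :], axis=1)
  cache_x = jnp.concatenate([last, cache_x], axis=1)
print(cache_x.shape)  # (1, 2, 4, 4, 8)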

@@ -428,16 +392,8 @@ def __call__(self, x: jax.Array):
     #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
-    print("qkv min: ", np.min(qkv))
-    print("qkv max: ", np.max(qkv))
     #q, k, v = jnp.split(qkv, 3, axis=-1)
     q, k, v = jnp.split(qkv, 3, axis=-2)
-    print("q min: ", np.min(q))
-    print("q max: ", np.max(q))
-    print("k min: ", np.min(k))
-    print("k min: ", np.max(k))
-    print("v min: ", np.min(v))
-    print("v min: ", np.max(v))
     q = jnp.transpose(q, (0, 1, 3, 2))
     k = jnp.transpose(k, (0, 1, 3, 2))
     v = jnp.transpose(v, (0, 1, 3, 2))
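
The surviving lines split the fused QKV projection along the channel axis (axis=-2 after the transpose) rather than the sequence axis the commented-out line used. A small shape check of that split (sizes are illustrative):

import jax.numpy as jnp

bt, length, channels = 2, 64, 32                  # batch*time, H*W, C
qkv = jnp.ones((bt, 1, length, channels * 3))
qkv = jnp.transpose(qkv, (0, 1, 3, 2))            # (bt, 1, 3C, L)
q, k, v = jnp.split(qkv, 3, axis=-2)              # split the channel axis
print(q.shape, k.shape, v.shape)                  # each (2, 1, 32, 64)
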
@@ -446,10 +402,8 @@
 
     # output projection
     x = self.proj(x)
-    #breakpoint()
     # Reshape back
     x = x.reshape(batch_size, time, height, width, channels)
-    #breakpoint()
 
     return x + identity

@@ -467,20 +421,11 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity
     self.resnets = resnets
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    print("WanMidblock...")
     x = self.resnets[0](x, feat_cache, feat_idx)
-    print("WanMidBlock resnets[0], x min: ", np.min(x))
-    print("WanMidBlock resnets[0], x max: ", np.max(x))
     for attn, resnet in zip(self.attentions, self.resnets[1:]):
-      print("WanMidBlock, for loop, attn len: ", len(self.attentions))
-      print("WanMidBlock, for loop, resnets len: ", len(self.resnets))
       if attn is not None:
         x = attn(x)
-        print("WanMidBlock attn[0], x min: ", np.min(x))
-        print("WanMidBlock attn[0], x max: ", np.max(x))
       x = resnet(x, feat_cache, feat_idx)
-      print("WanMidBlock resnets[i], x min: ", np.min(x))
-      print("WanMidBlock resnets[i], x max: ", np.max(x))
     return x

@@ -888,21 +833,18 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
         out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
       else:
         out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-        out = jnp.concatenate([out, out_], axis=1)
-        print("out_.shape: ", out_.shape)
-        print("out_ min: ", np.min(out_))
-        print("out_ max: ", np.max(out_))
-        print("out.shape: ", out.shape)
-        print("out min: ", np.min(out))
-        print("out max: ", np.max(out))
-        for i in range(len(self._feat_map)):
-          if isinstance(self._feat_map[i], jax.Array):
-            print("i: ", i)
-            print("min: ", np.min(self._feat_map[i]))
-            print("max: ", np.max(self._feat_map[i]))
-          else:
-            print(f"feat_map[{i}] : {self._feat_map[i]}")
-        # breakpoint()
+
+        # Workaround for an issue where frame[1] and frame[2] come out swapped.
+        # Ideally this should not be needed, but the point where the frames go
+        # out of sync has not been found; most likely an incorrect reshape in the decoder.
+        fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
+        if len(fm1.shape) == 4:
+          fm1 = jnp.expand_dims(fm1, axis=0)
+          fm2 = jnp.expand_dims(fm2, axis=0)
+          fm3 = jnp.expand_dims(fm3, axis=0)
+          fm4 = jnp.expand_dims(fm4, axis=0)
+
+        out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
     out = jnp.clip(out, min=-1.0, max=1.0)
     self.clear_cache()
     if not return_dict:
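
The workaround above reorders each decoded 4-frame chunk from [0, 1, 2, 3] to [0, 2, 1, 3]. Note that the expand_dims(..., axis=0) path implicitly assumes batch size 1: slicing frame t drops the time axis, and re-expanding at axis 0 only restores the (B, T, H, W, C) layout when B == 1. A batch-size-agnostic sketch of the same reorder (illustrative, not part of the commit):

import jax.numpy as jnp

out_ = jnp.arange(2 * 4 * 3 * 3 * 1.0).reshape(2, 4, 3, 3, 1)  # (B, T=4, H, W, C)

# Index the time axis directly: the same [0, 2, 1, 3] swap of frames 1 and 2,
# but the batch dimension is preserved for any B.
reordered = out_[:, jnp.array([0, 2, 1, 3]), :, :, :]
assert reordered.shape == out_.shape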
