Skip to content

Commit 04f4909

Browse files
Fixes distorted decoded video. The video is now jittery, but individual frames are correct.
1 parent 34ebdbe commit 04f4909

1 file changed

Lines changed: 9 additions & 22 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def __init__(
286286

287287
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
288288
# Input x: (N, D, H, W, C), assume C = self.dim
289-
n, d, h, w, c = x.shape
289+
b, t, h, w, c = x.shape
290290
assert c == self.dim
291291

292292
if self.mode == "upsample3d":
@@ -308,14 +308,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
308308
x = self.time_conv(x, feat_cache[idx])
309309
feat_cache[idx] = cache_x
310310
feat_idx[0] += 1
311-
x = x.reshape(n, 2, d, h, w, c)
312-
x = jnp.stack([x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]], axis=2)
313-
x = x.reshape(n, d * 2, h, w, c)
314-
d = x.shape[1]
315-
x = x.reshape(n * d, h, w, c)
311+
x = x.reshape(b, t, h, w, 2, c)
312+
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
313+
x = x.reshape(b, t * 2, h, w, c)
314+
t = x.shape[1]
315+
x = x.reshape(b * t, h, w, c)
316316
x = self.resample(x)
317317
h_new, w_new, c_new = x.shape[1:]
318-
x = x.reshape(n, d, h_new, w_new, c_new)
318+
x = x.reshape(b, t, h_new, w_new, c_new)
319319

320320
if self.mode == "downsample3d":
321321
if feat_cache is not None:
@@ -425,7 +425,6 @@ def __call__(self, x: jax.Array):
425425
x = self.norm(x)
426426

427427
qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3)
428-
#breakpoint()
429428
#qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
430429
qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
431430
qkv = jnp.transpose(qkv, (0, 1, 3, 2))
@@ -439,21 +438,10 @@ def __call__(self, x: jax.Array):
439438
print("k min: ", np.max(k))
440439
print("v min: ", np.min(v))
441440
print("v min: ", np.max(v))
442-
#breakpoint()
443441
q = jnp.transpose(q, (0, 1, 3, 2))
444442
k = jnp.transpose(k, (0, 1, 3, 2))
445443
v = jnp.transpose(v, (0, 1, 3, 2))
446-
import torch
447-
import torch.nn.functional as F
448-
q = torch.tensor(np.array(q, dtype=np.float32))
449-
k = torch.tensor(np.array(k, dtype=np.float32))
450-
v = torch.tensor(np.array(v, dtype=np.float32))
451-
#x = jax.nn.dot_product_attention(q, k, v)
452-
x = F.scaled_dot_product_attention(q, k, v)
453-
print("attn min: ", torch.min(x))
454-
print("attn max: ", torch.max(x))
455-
#breakpoint()
456-
x = jnp.array(x.detach().numpy())
444+
x = jax.nn.dot_product_attention(q, k, v)
457445
x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
458446

459447
# output projection
@@ -696,7 +684,7 @@ def __init__(
696684
upsample_mode = None
697685
if i != len(dim_mult) - 1:
698686
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
699-
# Crete and add the upsampling block
687+
# Create and add the upsampling block
700688
up_block = WanUpBlock(
701689
in_dim=in_dim,
702690
out_dim=out_dim,
@@ -731,7 +719,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
731719

732720
## middle
733721
x = self.mid_block(x, feat_cache, feat_idx)
734-
#breakpoint()
735722
## upsamples
736723
for up_block in self.up_blocks:
737724
x = up_block(x, feat_cache, feat_idx)

0 commit comments

Comments
 (0)