@@ -286,7 +286,7 @@ def __init__(
286286
287287 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = [0 ]) -> jax .Array :
288288 # Input x: (N, D, H, W, C), assume C = self.dim
289- n , d , h , w , c = x .shape
289+ b , t , h , w , c = x .shape
290290 assert c == self .dim
291291
292292 if self .mode == "upsample3d" :
@@ -308,14 +308,14 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
308308 x = self .time_conv (x , feat_cache [idx ])
309309 feat_cache [idx ] = cache_x
310310 feat_idx [0 ] += 1
311- x = x .reshape (n , 2 , d , h , w , c )
312- x = jnp .stack ([x [:, 0 , :, :, : , :], x [:, 1 , :, :, : , :]], axis = 2 )
313- x = x .reshape (n , d * 2 , h , w , c )
314- d = x .shape [1 ]
315- x = x .reshape (n * d , h , w , c )
311+ x = x .reshape (b , t , h , w , 2 , c )
312+ x = jnp .stack ([x [:, : , :, :, 0 , :], x [:, : , :, :, 1 , :]], axis = 1 )
313+ x = x .reshape (b , t * 2 , h , w , c )
314+ t = x .shape [1 ]
315+ x = x .reshape (b * t , h , w , c )
316316 x = self .resample (x )
317317 h_new , w_new , c_new = x .shape [1 :]
318- x = x .reshape (n , d , h_new , w_new , c_new )
318+ x = x .reshape (b , t , h_new , w_new , c_new )
319319
320320 if self .mode == "downsample3d" :
321321 if feat_cache is not None :
@@ -425,7 +425,6 @@ def __call__(self, x: jax.Array):
425425 x = self .norm (x )
426426
427427 qkv = self .to_qkv (x ) # Output: (N*D, H, W, C * 3)
428- #breakpoint()
429428 #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
430429 qkv = qkv .reshape (batch_size * time , 1 , - 1 , channels * 3 )
431430 qkv = jnp .transpose (qkv , (0 , 1 , 3 , 2 ))
@@ -439,21 +438,10 @@ def __call__(self, x: jax.Array):
439438 print ("k min: " , np .max (k ))
440439 print ("v min: " , np .min (v ))
441440 print ("v min: " , np .max (v ))
442- #breakpoint()
443441 q = jnp .transpose (q , (0 , 1 , 3 , 2 ))
444442 k = jnp .transpose (k , (0 , 1 , 3 , 2 ))
445443 v = jnp .transpose (v , (0 , 1 , 3 , 2 ))
446- import torch
447- import torch .nn .functional as F
448- q = torch .tensor (np .array (q , dtype = np .float32 ))
449- k = torch .tensor (np .array (k , dtype = np .float32 ))
450- v = torch .tensor (np .array (v , dtype = np .float32 ))
451- #x = jax.nn.dot_product_attention(q, k, v)
452- x = F .scaled_dot_product_attention (q , k , v )
453- print ("attn min: " , torch .min (x ))
454- print ("attn max: " , torch .max (x ))
455- #breakpoint()
456- x = jnp .array (x .detach ().numpy ())
444+ x = jax .nn .dot_product_attention (q , k , v )
457445 x = jnp .squeeze (x , 1 ).reshape (batch_size * time , height , width , channels )
458446
459447 # output projection
@@ -696,7 +684,7 @@ def __init__(
696684 upsample_mode = None
697685 if i != len (dim_mult ) - 1 :
698686 upsample_mode = "upsample3d" if temperal_upsample [i ] else "upsample2d"
699- # Crete and add the upsampling block
687+ # Create and add the upsampling block
700688 up_block = WanUpBlock (
701689 in_dim = in_dim ,
702690 out_dim = out_dim ,
@@ -731,7 +719,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
731719
732720 ## middle
733721 x = self .mid_block (x , feat_cache , feat_idx )
734- #breakpoint()
735722 ## upsamples
736723 for up_block in self .up_blocks :
737724 x = up_block (x , feat_cache , feat_idx )
0 commit comments