@@ -94,50 +94,31 @@ def __init__(
9494 )
9595
9696 def __call__ (self , x : jax .Array , cache_x : Optional [jax .Array ] = None , idx = - 1 ) -> jax .Array :
97- print ("wanCausalConv3d, x min: " , np .min (x ))
98- print ("wanCausalConv3d, x max: " , np .max (x ))
9997 current_padding = list (self ._causal_padding ) # Mutable copy
10098 padding_needed = self ._depth_padding_before
10199
102100 if cache_x is not None and padding_needed > 0 :
103- print ("WanCausalConv3d, cache.shape: " , cache_x .shape )
104- print ("wanCausalConv3d, cache_x min: " , np .min (cache_x ))
105- print ("wanCausalConv3d, cache_x max: " , np .max (cache_x ))
106101 # Ensure cache has same spatial/channel dims, potentially different depth
107102 assert cache_x .shape [0 ] == x .shape [0 ] and cache_x .shape [2 :] == x .shape [2 :], "Cache spatial/channel dims mismatch"
108103 cache_len = cache_x .shape [1 ]
109104 x = jnp .concatenate ([cache_x , x ], axis = 1 ) # Concat along depth (D)
110105
111106 padding_needed -= cache_len
112107 if padding_needed < 0 :
113- print ("wanCausanConv3d, padding_needed < 0" )
114108 # Cache longer than needed padding, trim from start
115109 x = x [:, - padding_needed :, ...]
116110 current_padding [1 ] = (0 , 0 ) # No explicit padding needed now
117111 else :
118112 # Update depth padding needed
119- print ("wanCausanConv3d, padding_needed > 0" )
120113 current_padding [1 ] = (padding_needed , 0 )
121114
122115 # Apply padding if any dimension requires it
123116 padding_to_apply = tuple (current_padding )
124- print ("WanCausalConv3d, before padding x shape: " , x .shape )
125117 if any (p > 0 for dim_pads in padding_to_apply for p in dim_pads ):
126- print ("WanCausalConv3d, applying padding" )
127118 x_padded = jnp .pad (x , padding_to_apply , mode = "constant" , constant_values = 0.0 )
128119 else :
129- print ("WanCausalConv3d, NOT applying padding" )
130120 x_padded = x
131-
132- print ("WanCausalConv3d, x shape: " , x_padded .shape )
133- print ("wanCausalConv3d, x min: " , np .min (x_padded ))
134- print ("wanCausalConv3d, x max: " , np .max (x_padded ))
135- # if idx == 12:
136- # breakpoint()
137121 out = self .conv (x_padded )
138- print ("WanCausalConv3d, after conv, x shape: " , out .shape )
139- print ("wanCausalConv3d, x min: " , np .min (out ))
140- print ("wanCausalConv3d, x max: " , np .max (out ))
141122 return out
142123
143124
@@ -300,8 +281,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
300281 if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] != "Rep" :
301282 # cache last frame of last two chunk
302283 cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
303- if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] != "Rep" :
304- cache_x = jnp .concatenate ([jnp .zeros (cache_x .shape ), cache_x ], dim = 1 )
284+ if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] == "Rep" :
285+ cache_x = jnp .concatenate ([jnp .zeros (cache_x .shape ), cache_x ], axis = 1 )
305286 if feat_cache [idx ] == "Rep" :
306287 x = self .time_conv (x )
307288 else :
@@ -364,13 +345,10 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
364345
365346 if feat_cache is not None :
366347 idx = feat_idx [0 ]
367- print ("Before conv1, idx: " , idx )
368348 cache_x = jnp .copy (x [:, - CACHE_T :, :, :, :])
369349 if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None :
370350 cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
371351 x = self .conv1 (x , feat_cache [idx ], idx )
372- # if idx == 4:
373- # breakpoint()
374352 feat_cache [idx ] = cache_x
375353 feat_idx [0 ] += 1
376354 else :
@@ -379,32 +357,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
379357 x = self .norm2 (x )
380358 x = self .nonlinearity (x )
381359 idx = feat_idx [0 ]
382- # if idx == 4:
383- # breakpoint()
384360
385361 if feat_cache is not None :
386362 idx = feat_idx [0 ]
387- print ("Residual block, idx: " , idx )
388- # if idx == 14:
389- # breakpoint()
390- print ("cache_x min: " , np .min (cache_x ))
391- print ("cache_x max: " , np .max (cache_x ))
392363 cache_x = jnp .copy (x [:, - CACHE_T :, :, :, :])
393364 if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None :
394365 cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
395- print ("cache_x min: " , np .min (cache_x ))
396- print ("cache_x max: " , np .max (cache_x ))
397- #breakpoint()
398366 x = self .conv2 (x , feat_cache [idx ])
399367 feat_cache [idx ] = cache_x
400368 feat_idx [0 ] += 1
401369 else :
402370 x = self .conv2 (x )
403- print ("before conv shortcut add: x min" , np .min (x ))
404- print ("before conv shortcut add: x max" , np .max (x ))
405371 x = x + h
406- print ("after conv shortcut add: x min: " , np .min (x ))
407- print ("after conv shortcut add: x max: " , np .max (x ))
408372 return x
409373
410374
@@ -428,16 +392,8 @@ def __call__(self, x: jax.Array):
428392 #qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
429393 qkv = qkv .reshape (batch_size * time , 1 , - 1 , channels * 3 )
430394 qkv = jnp .transpose (qkv , (0 , 1 , 3 , 2 ))
431- print ("qkv min: " , np .min (qkv ))
432- print ("qkv max: " , np .max (qkv ))
433395 #q, k, v = jnp.split(qkv, 3, axis=-1)
434396 q , k , v = jnp .split (qkv , 3 , axis = - 2 )
435- print ("q min: " , np .min (q ))
436- print ("q max: " , np .max (q ))
437- print ("k min: " , np .min (k ))
438- print ("k min: " , np .max (k ))
439- print ("v min: " , np .min (v ))
440- print ("v min: " , np .max (v ))
441397 q = jnp .transpose (q , (0 , 1 , 3 , 2 ))
442398 k = jnp .transpose (k , (0 , 1 , 3 , 2 ))
443399 v = jnp .transpose (v , (0 , 1 , 3 , 2 ))
@@ -446,10 +402,8 @@ def __call__(self, x: jax.Array):
446402
447403 # output projection
448404 x = self .proj (x )
449- #breakpoint()
450405 # Reshape back
451406 x = x .reshape (batch_size , time , height , width , channels )
452- #breakpoint()
453407
454408 return x + identity
455409
@@ -467,20 +421,11 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity
467421 self .resnets = resnets
468422
469423 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = [0 ]):
470- print ("WanMidblock..." )
471424 x = self .resnets [0 ](x , feat_cache , feat_idx )
472- print ("WanMidBlock resnets[0], x min: " , np .min (x ))
473- print ("WanMidBlock resnets[0], x max: " , np .max (x ))
474425 for attn , resnet in zip (self .attentions , self .resnets [1 :]):
475- print ("WanMidBlock, for loop, attn len: " , len (self .attentions ))
476- print ("WanMidBlock, for loop, resnets len: " , len (self .resnets ))
477426 if attn is not None :
478427 x = attn (x )
479- print ("WanMidBlock attn[0], x min: " , np .min (x ))
480- print ("WanMidBlock attn[0], x max: " , np .max (x ))
481428 x = resnet (x , feat_cache , feat_idx )
482- print ("WanMidBlock resnets[i], x min: " , np .min (x ))
483- print ("WanMidBlock resnets[i], x max: " , np .max (x ))
484429 return x
485430
486431
@@ -888,21 +833,18 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
888833 out = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
889834 else :
890835 out_ = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
891- out = jnp .concatenate ([out , out_ ], axis = 1 )
892- print ("out_.shape: " , out_ .shape )
893- print ("out_ min: " , np .min (out_ ))
894- print ("out_ max: " , np .max (out_ ))
895- print ("out.shape: " , out .shape )
896- print ("out min: " , np .min (out ))
897- print ("out max: " , np .max (out ))
898- for i in range (len (self ._feat_map )):
899- if isinstance (self ._feat_map [i ], jax .Array ):
900- print ("i: " , i )
901- print ("min: " , np .min (self ._feat_map [i ]))
902- print ("max: " , np .max (self ._feat_map [i ]))
903- else :
904- print (f"feat_map[{ i } ] : { self ._feat_map [i ]} " )
905- # breakpoint()
836+
837+ # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
838+ # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
839+ # Most likely due to an incorrect reshaping in the decoder.
840+ fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
841+ if len (fm1 .shape ) == 4 :
842+ fm1 = jnp .expand_dims (fm1 , axis = 0 )
843+ fm2 = jnp .expand_dims (fm2 , axis = 0 )
844+ fm3 = jnp .expand_dims (fm3 , axis = 0 )
845+ fm4 = jnp .expand_dims (fm4 , axis = 0 )
846+
847+ out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
906848 out = jnp .clip (out , min = - 1.0 , max = 1.0 )
907849 self .clear_cache ()
908850 if not return_dict :
0 commit comments