 from ..modeling_flax_utils import FlaxModelMixin
 from ... import common_types
 from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
-
+import numpy as np
 BlockSizes = common_types.BlockSizes
 
 CACHE_T = 2
@@ -93,33 +93,51 @@ def __init__(
             rngs=rngs,
         )
 
-    def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array:
+    def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
+        print("WanCausalConv3d, x min: ", np.min(x))
+        print("WanCausalConv3d, x max: ", np.max(x))
         current_padding = list(self._causal_padding)  # Mutable copy
         padding_needed = self._depth_padding_before
 
         if cache_x is not None and padding_needed > 0:
+            print("WanCausalConv3d, cache_x.shape: ", cache_x.shape)
+            print("WanCausalConv3d, cache_x min: ", np.min(cache_x))
+            print("WanCausalConv3d, cache_x max: ", np.max(cache_x))
             # Ensure cache has same spatial/channel dims, potentially different depth
             assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
             cache_len = cache_x.shape[1]
             x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
 
             padding_needed -= cache_len
             if padding_needed < 0:
+                print("WanCausalConv3d, padding_needed < 0")
                 # Cache longer than needed padding, trim from start
                 x = x[:, -padding_needed:, ...]
                 current_padding[1] = (0, 0)  # No explicit padding needed now
             else:
                 # Update depth padding needed
+                print("WanCausalConv3d, padding_needed >= 0")
                 current_padding[1] = (padding_needed, 0)
 
         # Apply padding if any dimension requires it
         padding_to_apply = tuple(current_padding)
+        print("WanCausalConv3d, before padding x shape: ", x.shape)
         if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
+            print("WanCausalConv3d, applying padding")
             x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
         else:
+            print("WanCausalConv3d, NOT applying padding")
             x_padded = x
 
+        print("WanCausalConv3d, x shape: ", x_padded.shape)
+        print("WanCausalConv3d, x min: ", np.min(x_padded))
+        print("WanCausalConv3d, x max: ", np.max(x_padded))
+        # if idx == 12:
+        #     breakpoint()
         out = self.conv(x_padded)
+        print("WanCausalConv3d, after conv, x shape: ", out.shape)
+        print("WanCausalConv3d, x min: ", np.min(out))
+        print("WanCausalConv3d, x max: ", np.max(out))
         return out
 
 
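For reference, the cache/padding dance instrumented above: a causal 3D conv pads only the front of the depth axis, and trailing frames cached from the previous chunk stand in for that padding so chunked decoding matches a full pass. A minimal standalone sketch of that logic, with illustrative names (`causal_depth_pad` is not part of the module):

    import jax.numpy as jnp

    def causal_depth_pad(x, pad_before, cache=None):
        """x: (B, D, H, W, C). Prepend `pad_before` frames on the depth axis,
        taken from `cache` where available and zeros otherwise."""
        if cache is not None and pad_before > 0:
            x = jnp.concatenate([cache, x], axis=1)
            pad_before -= cache.shape[1]
            if pad_before < 0:
                # Cache longer than the padding we need: drop the excess from the front.
                return x[:, -pad_before:, ...]
        if pad_before > 0:
            x = jnp.pad(x, ((0, 0), (pad_before, 0), (0, 0), (0, 0), (0, 0)))
        return x

    x = jnp.ones((1, 4, 8, 8, 16))
    assert causal_depth_pad(x, 2).shape == (1, 6, 8, 8, 16)                   # zero padding
    assert causal_depth_pad(x, 2, cache=x[:, -2:]).shape == (1, 6, 8, 8, 16)  # cache padding

Either way the conv sees the same depth, which is what keeps chunked and one-shot decodes numerically aligned.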
@@ -346,31 +364,48 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
 
         if feat_cache is not None:
             idx = feat_idx[0]
+            print("Before conv1, idx: ", idx)
             cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
             if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                 cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
-
-            x = self.conv1(x, feat_cache[idx])
+            x = self.conv1(x, feat_cache[idx], idx)
+            # if idx == 4:
+            #     breakpoint()
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv1(x)
 
         x = self.norm2(x)
         x = self.nonlinearity(x)
+        idx = feat_idx[0]
+        # if idx == 4:
+        #     breakpoint()
 
         if feat_cache is not None:
             idx = feat_idx[0]
+            print("Residual block, idx: ", idx)
+            # if idx == 14:
+            #     breakpoint()
+            print("cache_x min: ", np.min(cache_x))
+            print("cache_x max: ", np.max(cache_x))
             cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
             if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                 cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+            print("cache_x min: ", np.min(cache_x))
+            print("cache_x max: ", np.max(cache_x))
+            # breakpoint()
             x = self.conv2(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
             x = self.conv2(x)
-
-        return x + h
+        print("before conv shortcut add: x min: ", np.min(x))
+        print("before conv shortcut add: x max: ", np.max(x))
+        x = x + h
+        print("after conv shortcut add: x min: ", np.min(x))
+        print("after conv shortcut add: x max: ", np.max(x))
+        return x
 
 
 class WanAttentionBlock(nnx.Module):
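The cache bookkeeping being traced in that hunk: each cached conv stores the last CACHE_T (= 2) frames of its input, and a single-frame chunk borrows the previous cache's final frame so the cache stays two frames deep once warm. A hedged sketch of just the update rule (the helper name is illustrative):

    import jax.numpy as jnp

    CACHE_T = 2

    def update_cache(x, prev_cache):
        """x: (B, D, H, W, C). Returns the cache to store for the next chunk."""
        cache = x[:, -CACHE_T:]
        if cache.shape[1] < CACHE_T and prev_cache is not None:
            cache = jnp.concatenate([prev_cache[:, -1:], cache], axis=1)
        return cache

    chunk = jnp.ones((1, 1, 8, 8, 4))      # single-frame chunk
    warm = update_cache(chunk, None)       # first chunk: only 1 frame cached
    assert warm.shape[1] == 1
    assert update_cache(chunk, warm).shape[1] == CACHE_T  # later chunks: 2 frames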
@@ -382,26 +417,51 @@ def __init__(self, dim: int, rngs: nnx.Rngs):
         self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs)
 
     def __call__(self, x: jax.Array):
-        batch_size, time, height, width, channels = x.shape
+
         identity = x
+        batch_size, time, height, width, channels = x.shape
 
         x = x.reshape(batch_size * time, height, width, channels)
         x = self.norm(x)
 
         qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
-
-        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        # breakpoint()
+        # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
         qkv = jnp.transpose(qkv, (0, 1, 3, 2))
-        q, k, v = jnp.split(qkv, 3, axis=-1)
-
-        x = jax.nn.dot_product_attention(q, k, v)
+        print("qkv min: ", np.min(qkv))
+        print("qkv max: ", np.max(qkv))
+        # q, k, v = jnp.split(qkv, 3, axis=-1)
+        q, k, v = jnp.split(qkv, 3, axis=-2)
+        print("q min: ", np.min(q))
+        print("q max: ", np.max(q))
+        print("k min: ", np.min(k))
+        print("k max: ", np.max(k))
+        print("v min: ", np.min(v))
+        print("v max: ", np.max(v))
+        # breakpoint()
+        q = jnp.transpose(q, (0, 1, 3, 2))
+        k = jnp.transpose(k, (0, 1, 3, 2))
+        v = jnp.transpose(v, (0, 1, 3, 2))
+        import torch
+        import torch.nn.functional as F
+        q = torch.tensor(np.array(q, dtype=np.float32))
+        k = torch.tensor(np.array(k, dtype=np.float32))
+        v = torch.tensor(np.array(v, dtype=np.float32))
+        # x = jax.nn.dot_product_attention(q, k, v)
+        x = F.scaled_dot_product_attention(q, k, v)
+        print("attn min: ", torch.min(x))
+        print("attn max: ", torch.max(x))
+        # breakpoint()
+        x = jnp.array(x.detach().numpy())
         x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
 
         # output projection
         x = self.proj(x)
-
+        # breakpoint()
         # Reshape back
         x = x.reshape(batch_size, time, height, width, channels)
+        # breakpoint()
 
         return x + identity
 
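The substantive fix in that hunk is the reshape order: `to_qkv` output is channels-last, (N*T, H, W, 3C), so the flattened view must keep 3C as the trailing axis before splitting; the old reshape to (N*T, 1, 3C, HW) carved q/k/v out of the wrong axis and interleaved garbage. The torch round-trip is a temporary numerical parity probe against `F.scaled_dot_product_attention` (it leaves JAX, so it cannot be jitted and would be reverted). An equivalent pure-JAX formulation, as a sketch assuming a recent `jax.nn.dot_product_attention` with (batch, seq_len, num_heads, head_dim) layout:

    import jax
    import jax.numpy as jnp

    def spatial_self_attention(qkv):
        """qkv: (N, H, W, 3*C) -> (N, H, W, C), attending over the H*W positions."""
        n, h, w, c3 = qkv.shape
        c = c3 // 3
        # Flatten spatial dims, keep 3C trailing, then split into q/k/v of (N, HW, C).
        q, k, v = jnp.split(qkv.reshape(n, h * w, c3), 3, axis=-1)
        # Single head: head_dim = C; add a heads axis of size 1.
        out = jax.nn.dot_product_attention(q[:, :, None], k[:, :, None], v[:, :, None])
        return out[:, :, 0].reshape(n, h, w, c)

    qkv = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3 * 16))
    assert spatial_self_attention(qkv).shape == (2, 8, 8, 16)

This computes the same single-head attention over spatial positions as the torch path, without leaving the JAX graph.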
@@ -419,11 +479,20 @@ def __init__(self, dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity
         self.resnets = resnets
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+        print("WanMidBlock...")
         x = self.resnets[0](x, feat_cache, feat_idx)
+        print("WanMidBlock resnets[0], x min: ", np.min(x))
+        print("WanMidBlock resnets[0], x max: ", np.max(x))
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            print("WanMidBlock, for loop, attn len: ", len(self.attentions))
+            print("WanMidBlock, for loop, resnets len: ", len(self.resnets))
             if attn is not None:
                 x = attn(x)
+                print("WanMidBlock attn[i], x min: ", np.min(x))
+                print("WanMidBlock attn[i], x max: ", np.max(x))
             x = resnet(x, feat_cache, feat_idx)
+            print("WanMidBlock resnets[i], x min: ", np.min(x))
+            print("WanMidBlock resnets[i], x max: ", np.max(x))
         return x
 
 
@@ -589,7 +658,7 @@ def __init__(
         self,
         rngs: nnx.Rngs,
         dim: int = 128,
-        z_dim: int = 128,
+        z_dim: int = 4,
         dim_mult: List[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
         attn_scales=List[float],
@@ -662,7 +731,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
 
         ## middle
         x = self.mid_block(x, feat_cache, feat_idx)
-
+        # breakpoint()
         ## upsamples
         for up_block in self.up_blocks:
             x = up_block(x, feat_cache, feat_idx)
@@ -810,7 +879,6 @@ def _encode(self, x: jax.Array):
         mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
         enc = jnp.concatenate([mu, logvar], axis=-1)
         self.clear_cache()
-        # return enc
         return enc
 
     def encode(
@@ -833,10 +901,22 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
                 out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
             else:
                 out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-
                 out = jnp.concatenate([out, out_], axis=1)
-
-        out = jnp.clip(out, a_min=-1.0, a_max=1.0)
+                print("out_.shape: ", out_.shape)
+                print("out_ min: ", np.min(out_))
+                print("out_ max: ", np.max(out_))
+                print("out.shape: ", out.shape)
+                print("out min: ", np.min(out))
+                print("out max: ", np.max(out))
+            for j in range(len(self._feat_map)):  # j, not i: avoid shadowing the frame loop
+                if isinstance(self._feat_map[j], jax.Array):
+                    print("j: ", j)
+                    print("min: ", np.min(self._feat_map[j]))
+                    print("max: ", np.max(self._feat_map[j]))
+                else:
+                    print(f"feat_map[{j}]: {self._feat_map[j]}")
+            # breakpoint()
+        out = jnp.clip(out, min=-1.0, max=1.0)
         self.clear_cache()
         if not return_dict:
             return (out,)
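Side note on the `jnp.clip` change in this hunk: recent JAX follows the Python array API and takes `min=`/`max=`; the old `a_min=`/`a_max=` spelling is deprecated and removed in newer releases. If this code has to run across JAX versions, positional arguments sidestep the rename entirely; a quick check:

    import jax.numpy as jnp

    out = jnp.linspace(-2.0, 2.0, 5)
    # The positional form works on both old and new JAX; min=/max= needs a recent release.
    assert (jnp.clip(out, -1.0, 1.0) == jnp.clip(out, min=-1.0, max=1.0)).all()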