1919import jax
2020import jax .numpy as jnp
2121from flax import nnx
22+ import numpy as np
2223from .... import common_types
2324from ...modeling_flax_utils import FlaxModelMixin , get_activation
2425from ....configuration_utils import ConfigMixin , register_to_config
@@ -58,12 +59,7 @@ def __init__(
5859 use_real = False
5960 )
6061 freqs .append (freq )
61- self .freqs = jnp .concatenate (freqs , axis = 1 )
62-
63- def __call__ (self , hidden_states : jax .Array ) -> jax .Array :
64- _ , num_frames , height , width , _ = hidden_states .shape
65- p_t , p_h , p_w = self .patch_size
66- ppf , pph , ppw = num_frames // p_t , height // p_h , width // p_w
62+ freqs = jnp .concatenate (freqs , axis = 1 )
6763
6864 sizes = [
6965 self .attention_head_dim // 2 - 2 * (self .attention_head_dim // 6 ),
@@ -72,16 +68,21 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
7268 ]
7369 cumulative_sizes = jnp .cumsum (jnp .array (sizes ))
7470 split_indices = cumulative_sizes [:- 1 ]
75- freqs_split = jnp .split (self .freqs , split_indices , axis = 1 )
71+ self .freqs_split = jnp .split (freqs , split_indices , axis = 1 )
72+
73+ def __call__ (self , hidden_states : jax .Array ) -> jax .Array :
74+ _ , num_frames , height , width , _ = hidden_states .shape
75+ p_t , p_h , p_w = self .patch_size
76+ ppf , pph , ppw = num_frames // p_t , height // p_h , width // p_w
7677
77- freqs_f = jnp .expand_dims (jnp .expand_dims (freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
78- freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ]))
78+ freqs_f = jnp .expand_dims (jnp .expand_dims (self . freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
79+ freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , self . freqs_split [0 ].shape [- 1 ]))
7980
80- freqs_h = jnp .expand_dims (jnp .expand_dims (freqs_split [1 ][:pph ], axis = 0 ), axis = 2 )
81- freqs_h = jnp .broadcast_to (freqs_h , (ppf , pph , ppw , freqs_split [1 ].shape [- 1 ]))
81+ freqs_h = jnp .expand_dims (jnp .expand_dims (self . freqs_split [1 ][:pph ], axis = 0 ), axis = 2 )
82+ freqs_h = jnp .broadcast_to (freqs_h , (ppf , pph , ppw , self . freqs_split [1 ].shape [- 1 ]))
8283
83- freqs_w = jnp .expand_dims (jnp .expand_dims (freqs_split [2 ][:ppw ], axis = 0 ), axis = 1 )
84- freqs_w = jnp .broadcast_to (freqs_w , (ppf , pph , ppw , freqs_split [2 ].shape [- 1 ]))
84+ freqs_w = jnp .expand_dims (jnp .expand_dims (self . freqs_split [2 ][:ppw ], axis = 0 ), axis = 1 )
85+ freqs_w = jnp .broadcast_to (freqs_w , (ppf , pph , ppw , self . freqs_split [2 ].shape [- 1 ]))
8586
8687 freqs_concat = jnp .concatenate ([freqs_f , freqs_h , freqs_w ], axis = - 1 )
8788 freqs_final = jnp .reshape (freqs_concat , (1 , 1 , ppf * pph * ppw , - 1 ))
@@ -361,22 +362,41 @@ def __call__(
361362 shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
362363 (self .scale_shift_table + temb .astype (jnp .float32 )), 6 , axis = 1
363364 )
365+ print ("Wan Block -- START -- " )
364366
365367 # 1. Self-attention
366368 norm_hidden_states = (self .norm1 (hidden_states .astype (jnp .float32 )) * (1 + scale_msa ) + shift_msa ).astype (hidden_states .dtype )
369+ print ("Wan Block -- norm_hidden_states, min: " , np .min (norm_hidden_states ))
370+ print ("Wan Block -- norm_hidden_states, max: " , np .max (norm_hidden_states ))
367371 attn_output = self .attn1 (hidden_states = norm_hidden_states , rotary_emb = rotary_emb )
372+ print ("Wan Block -- Self Attn. attn_output, min: " , np .min (attn_output ))
373+ print ("Wan Block -- Self Attn. attn_output, max: " , np .max (attn_output ))
368374 hidden_states = (hidden_states .astype (jnp .float32 ) + attn_output * gate_msa ).astype (hidden_states .dtype )
375+ print ("Wan Block -- hidden_states, min: " , np .min (hidden_states ))
376+ print ("Wan Block -- hidden_states, max: " , np .max (hidden_states ))
369377
370378 # 2. Cross-attention
371379 norm_hidden_states = self .norm2 (hidden_states .astype (jnp .float32 ))
380+ print ("Wan Block -- norm_hidden_states, min: " , np .min (norm_hidden_states ))
381+ print ("Wan Block -- norm_hidden_states, max: " , np .max (norm_hidden_states ))
372382 attn_output = self .attn2 (hidden_states = norm_hidden_states , encoder_hidden_states = encoder_hidden_states )
383+ print ("Wan Block -- Cross Attn. attn_output, min: " , np .min (attn_output ))
384+ print ("Wan Block -- Cross Attn. attn_output, max: " , np .max (attn_output ))
373385 hidden_states = hidden_states + attn_output
386+ print ("Wan Block -- hidden_states, min: " , np .min (hidden_states ))
387+ print ("Wan Block -- hidden_states, max: " , np .max (hidden_states ))
374388
375389 # 3. Feed-forward
376390 norm_hidden_states = (self .norm3 (hidden_states .astype (jnp .float32 )) * (1 + c_scale_msa ) + c_shift_msa ).astype (hidden_states .dtype )
377-
391+ print ("Wan Block -- norm_hidden_states, min: " , np .min (norm_hidden_states ))
392+ print ("Wan Block -- norm_hidden_states, max: " , np .max (norm_hidden_states ))
378393 ff_output = self .ffn (norm_hidden_states )
394+ print ("Wan Block -- ff_output, min: " , np .min (ff_output ))
395+ print ("Wan Block -- ff_output, max: " , np .max (ff_output ))
379396 hidden_states = (hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa ).astype (hidden_states .dtype )
397+ print ("Wan Block -- hidden_states, min: " , np .min (hidden_states ))
398+ print ("Wan Block -- hidden_states, max: " , np .max (hidden_states ))
399+ print ("Wan Block -- COMPLETE -- " )
380400 return hidden_states
381401
382402
@@ -495,19 +515,32 @@ def __call__(
495515
496516 rotary_emb = self .rope (hidden_states )
497517 hidden_states = self .patch_embedding (hidden_states )
518+ print ("***** After patch embedding" )
519+ print ("hidden_states, min: " , np .min (hidden_states ))
520+ print ("hidden_states, max: " , np .max (hidden_states ))
498521 hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
499522
500523 temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image = self .condition_embedder (
501524 timestep , encoder_hidden_states , encoder_hidden_states_image
502525 )
526+ print ("***** After condition embedder" )
527+ print ("temb, min: " , np .min (temb ))
528+ print ("temb, max: " , np .max (temb ))
529+ print ("timestep_proj, min: " , np .min (timestep_proj ))
530+ print ("timestep_proj, max: " , np .max (timestep_proj ))
531+ print ("encoder_hidden_states min: " , np .min (encoder_hidden_states ))
532+ print ("encoder_hidden_states max: " , np .max (encoder_hidden_states ))
533+
503534 timestep_proj = timestep_proj .reshape (timestep_proj .shape [0 ], 6 , - 1 )
504535
505536 if encoder_hidden_states_image is not None :
506537 raise NotImplementedError ("img2vid is not yet implemented." )
507538
508539 for block in self .blocks :
509540 hidden_states = block (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
510-
541+ print ("After block, hidden_states min:" , np .min (hidden_states ))
542+ print ("After block, hidden_states max:" , np .max (hidden_states ))
543+ #breakpoint()
511544 shift , scale = jnp .split (self .scale_shift_table + jnp .expand_dims (temb , axis = 1 ), 2 , axis = 1 )
512545
513546 hidden_states = (self .norm_out (hidden_states .astype (jnp .float32 )) * (1 + scale ) + shift ).astype (hidden_states .dtype )
0 commit comments