@@ -101,7 +101,8 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   """
-  tensor = _unflatten_heads(tensor, heads)
+  if tensor.ndim != 4:
+    tensor = _unflatten_heads(tensor, heads)

   # pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
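Note: the new guard skips `_unflatten_heads` when the tensor is already 4-D. A minimal sketch of the shape handling this implies, assuming `_unflatten_heads` maps (batch, seq, heads * head_dim) to (batch, heads, seq, head_dim) (its body is not shown in this diff); only the head_dim padding is sketched, the seq_len padding mentioned in the docstring is omitted:

import jax.numpy as jnp

def _unflatten_heads_sketch(tensor, heads):
  # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
  batch, seq_len, _ = tensor.shape
  tensor = tensor.reshape(batch, seq_len, heads, -1)
  return jnp.transpose(tensor, (0, 2, 1, 3))

def _reshape_for_flash_sketch(tensor, heads):
  if tensor.ndim != 4:            # already (batch, heads, seq, head_dim)? leave it alone
    tensor = _unflatten_heads_sketch(tensor, heads)
  head_dim = tensor.shape[-1]
  if head_dim < 128:              # pad head_dim up to the 128-lane tile
    tensor = jnp.pad(tensor, ((0, 0), (0, 0), (0, 0), (0, 128 - head_dim)))
  return tensor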
@@ -319,12 +320,14 @@ def _apply_attention(
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
-
+  seq_len_idx = 1
+  if query.ndim == 4:
+    seq_len_idx = 2
   if attention_kernel == "flash":
     can_use_flash_attention = (
-        query.shape[1] >= flash_min_seq_length
-        and key.shape[1] >= flash_min_seq_length
-        and value.shape[1] >= flash_min_seq_length
+        query.shape[seq_len_idx] >= flash_min_seq_length
+        and key.shape[seq_len_idx] >= flash_min_seq_length
+        and value.shape[seq_len_idx] >= flash_min_seq_length
     )
   else:
     can_use_flash_attention = True
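Note: `seq_len_idx` reflects that query/key/value may arrive either flattened as (batch, seq, heads * head_dim) or unflattened as (batch, heads, seq, head_dim); the sequence length sits on axis 1 in the first layout and axis 2 in the second. A small illustrative check, with hypothetical shapes and threshold:

import jax.numpy as jnp

flash_min_seq_length = 4096            # hypothetical threshold, not taken from this diff

q3 = jnp.zeros((1, 8192, 8 * 64))      # flattened layout: seq_len on axis 1
q4 = jnp.zeros((1, 8, 8192, 64))       # unflattened layout: seq_len on axis 2

for q in (q3, q4):
  seq_len_idx = 2 if q.ndim == 4 else 1
  print(q.shape[seq_len_idx] >= flash_min_seq_length)  # True for both layouts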
@@ -584,7 +587,6 @@ def __init__(

     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
-
     self.dim_head = dim_head
     self.heads = heads
     self.inner_dim = dim_head * heads
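Note: this hunk only drops a blank line; the check itself requires a device mesh whenever a flash kernel is selected. For reference, a minimal sketch of building a mesh that satisfies it (the 2-D device layout and axis names are assumptions, not taken from this diff):

import jax
import numpy as np
from jax.sharding import Mesh

devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp"))  # assumed axis names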
@@ -681,7 +683,6 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
     xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
     xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

-    freqs_cis = freqs_cis[None, None, ...]
     xq_out_complex = xq_ * freqs_cis
     xk_out_complex = xk_ * freqs_cis

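Note: the deleted broadcast suggests `freqs_cis` now already broadcasts against the unflattened queries/keys without adding leading batch/head axes here. A minimal sketch of the complex-number RoPE pattern used in this method, assuming (batch, heads, seq, head_dim) inputs and a `freqs_cis` broadcastable to (batch, heads, seq, head_dim // 2):

import jax
import jax.numpy as jnp

def apply_rope_sketch(xq, xk, freqs_cis):
  def to_complex(x):
    # pair adjacent features and view them as complex numbers (lax.complex needs float32)
    x = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)
    return jax.lax.complex(x[..., 0], x[..., 1])

  def to_real(x):
    # interleave real/imag back into the last feature axis
    return jnp.stack([jnp.real(x), jnp.imag(x)], axis=-1).reshape(*x.shape[:-1], -1)

  xq_out = to_real(to_complex(xq) * freqs_cis)
  xk_out = to_real(to_complex(xk) * freqs_cis)
  return xq_out.astype(xq.dtype), xk_out.astype(xk.dtype)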
@@ -696,58 +697,26 @@ def __call__(
     encoder_hidden_states: jax.Array = None,
     rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
-    print(" -- -- WanAttention -- ")
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)
-    print("query_proj min: ", np.min(query_proj))
-    print("query_proj max: ", np.max(query_proj))
     key_proj = self.key(encoder_hidden_states)
-    print("key_proj min: ", np.min(key_proj))
-    print("key_proj max: ", np.max(key_proj))
     value_proj = self.value(encoder_hidden_states)
-    print("value_proj min: ", np.min(value_proj))
-    print("value_proj max: ", np.max(value_proj))
-
-    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
-    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
-    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
       key_proj = self.norm_k(key_proj)
-      print("query_proj min: ", np.min(query_proj))
-      print("query_proj max: ", np.max(query_proj))
-      print("key_proj min: ", np.min(key_proj))
-      print("key_proj max: ", np.max(key_proj))
-
     if rotary_emb is not None:
       query_proj = _unflatten_heads(query_proj, self.heads)
       key_proj = _unflatten_heads(key_proj, self.heads)
-      # value_proj = _unflatten_heads(value_proj, self.heads)
+      value_proj = _unflatten_heads(value_proj, self.heads)
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-      print("Rope query_proj min: ", np.min(query_proj))
-      print("Rope query_proj max: ", np.max(query_proj))
-      print("Rope key_proj min: ", np.min(key_proj))
-      print("Rope key_proj max: ", np.max(key_proj))
-      #breakpoint()
-      query_proj = _reshape_heads_to_head_dim(query_proj)
-      key_proj = _reshape_heads_to_head_dim(key_proj)

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    try:
-      print("attn_output min: ", np.min(attn_output))
-      print("attn_output_for_print max: ", np.max(attn_output))
-    except:
-      pass
     attn_output = attn_output.astype(dtype=dtype)

-    hidden_states = self.proj_attn(hidden_states)
-    print("hidden_states min: ", np.min(hidden_states))
-    print("hidden_states max: ", np.max(hidden_states))
-    print(" -- -- WanAttention DONE -- ")
-    #breakpoint()
+    hidden_states = self.proj_attn(attn_output)
     return hidden_states

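Note: with the debug prints and logical-constraint calls removed, the forward pass reduces to the flow below. The functional fixes are that `value_proj` is now unflattened alongside query/key before attention and that the output projection is applied to `attn_output` rather than to the original `hidden_states`. A condensed sketch of the resulting flow, with the Flax submodules passed in as plain callables (everything outside the names shown in this diff is an assumption):

def wan_attention_forward_sketch(
    hidden_states, encoder_hidden_states, rotary_emb,
    query, key, value, norm_q, norm_k, unflatten_heads, apply_rope,
    attention_op, proj_attn, heads, qk_norm=True,
):
  dtype = hidden_states.dtype
  if encoder_hidden_states is None:
    encoder_hidden_states = hidden_states
  q = query(hidden_states)
  k = key(encoder_hidden_states)
  v = value(encoder_hidden_states)
  if qk_norm:
    q, k = norm_q(q), norm_k(k)
  if rotary_emb is not None:
    q = unflatten_heads(q, heads)
    k = unflatten_heads(k, heads)
    v = unflatten_heads(v, heads)   # value now unflattened too
    q, k = apply_rope(q, k, rotary_emb)
  out = attention_op.apply_attention(q, k, v).astype(dtype)
  return proj_attn(out)             # fixed: project attn_output, not hidden_states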