@@ -101,16 +101,28 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
   Reshapes tensors for Pallas flash attention, adding padding to both seq_len and head_dim.
   """
   tensor = _unflatten_heads(tensor, heads)
+
+  # Pad head_dim up to 128 if it is smaller.
   kv_size = tensor.shape[-1]
+  head_dim_pad = 0
   if kv_size < 128:
-    npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size))
-    tensor = jnp.pad(tensor, npad)
+    head_dim_pad = 128 - kv_size
+
+  # Pad seq_len to a multiple of flash_block_size if needed.
   seq_len = tensor.shape[2]
+  # Remainder of seq_len modulo flash_block_size.
   rem = seq_len % flash_block_size
+  seq_len_pad = 0
   if rem != 0:
+    # Number of whole flash blocks that fit in seq_len.
     mul = seq_len // flash_block_size
-    npad = ((0, 0), (0, 0), (0, (mul + 1) * flash_block_size - seq_len), (0, 0))
+    # Pad up to the next multiple of flash_block_size.
+    seq_len_pad = (mul + 1) * flash_block_size - seq_len
+
+  if kv_size < 128 or rem != 0:
+    npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
     tensor = jnp.pad(tensor, npad)
+
   return tensor, kv_size, seq_len

 def _tpu_flash_attention(
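
For intuition, here is a minimal sketch (not part of the diff; the shapes and variable values are illustrative) of how the two padding amounts computed above combine into a single `jnp.pad` call:

```python
import jax.numpy as jnp

# Toy tensor: (batch, heads, seq_len, head_dim) = (1, 8, 75, 64), flash block of 64.
tensor = jnp.zeros((1, 8, 75, 64))
flash_block_size = 64

kv_size = tensor.shape[-1]
head_dim_pad = 128 - kv_size if kv_size < 128 else 0      # 64 -> pad head_dim to 128

seq_len = tensor.shape[2]
rem = seq_len % flash_block_size                           # 75 % 64 = 11
seq_len_pad = 0
if rem != 0:
  mul = seq_len // flash_block_size                        # one full block fits
  seq_len_pad = (mul + 1) * flash_block_size - seq_len     # pad 75 -> 128

npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
padded = jnp.pad(tensor, npad)
print(padded.shape)  # (1, 8, 128, 128)
```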
@@ -140,15 +152,7 @@ def _tpu_flash_attention(
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
-  # query_seq_len = query.shape[2]
-  # query_rem = query_seq_len % block_sizes.block_q
-  # if query_rem != 0:
-  #   query_mul = query_seq_len // block_sizes.block_q
-  #   npad = ((0, 0), (0, 0), (0, (query_mul + 1) * block_sizes.block_q - query.shape[2]), (0, 0))
-  #   query = jnp.pad(query, npad)
-  #   key = jnp.pad(key, npad)
-  #   value = jnp.pad(value, npad)
-  # breakpoint()
+
   axis_names = nn.logical_to_mesh_axes(flash_axis_names)

   @functools.partial(
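
The sizes returned by `_reshape_data_for_flash` (`kv_size` and `query_seq_len`) are presumably what lets the caller strip the padding off again after the kernel runs. A hedged, self-contained sketch of that slicing step, which is not shown in this hunk:

```python
import jax.numpy as jnp

# Assumed shapes: the kernel ran on padded (1, 8, 128, 128);
# the original query had 75 tokens and head_dim 64.
query_seq_len, kv_size = 75, 64
padded_out = jnp.zeros((1, 8, 128, 128))

attn_output = padded_out[:, :, :query_seq_len, :kv_size]
print(attn_output.shape)  # (1, 8, 75, 64)
```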
@@ -456,7 +460,7 @@ def __init__(
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
-      raise NotImplementedError("Wan 2.1 has not been tested with cudnn_flash_te")
+      raise NotImplementedError(f"{self} has not been tested with {attention_kernel}")

     self.mesh = mesh
     self.scale = scale
@@ -574,34 +578,13 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
   ):
-    # TODO - Params from pytorch implementation
-    # to set for the creation of this.
-    # bias is True
-    # upcast_attention - False
-    # upcast_softmax - False
-    # cross_attention_norm - None
-    # cross_attention_norm_num_groups - 32
-    # qk_norm - rms_norm_across_heads
-    # added_kv_proj_dim
-    # norm_num_groups: Optional[int] = None,
-    # spatial_norm_dim: Optional[int] = None,
-    # out_bias: bool = True,
-    # scale_qk: bool = True,
-    # only_cross_attention - False
-    # eps - 1e-06
-    # rescale_output_factor: float = 1.0,
-    # residual_connection: bool = False,
-    # _from_deprecated_attn_block: bool = False,
-    # processor: Optional["AttnProcessor"] = WanAttnProcessor2_0
-    # out_dim: int = None,
-    # out_context_dim: int = None,
-    # context_pre_only=None,
-    # pre_only=False,
-    # elementwise_affine: bool = True,
-    # is_causal: bool = False,
+
+    if attention_kernel in {"cudnn_flash_te", "dot_product"}:
+      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+
     self.dim_head = dim_head
     self.heads = heads
     self.inner_dim = dim_head * heads
@@ -717,7 +700,8 @@ def __call__(
       encoder_hidden_states: jax.Array,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
-    batch_size = hidden_states.shape[0]
+    dtype = hidden_states.dtype
+    # batch_size = hidden_states.shape[0]
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)
@@ -735,35 +719,15 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     if rotary_emb is not None:
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-    #breakpoint()
     query_proj = _reshape_heads_to_head_dim(query_proj)
     key_proj = _reshape_heads_to_head_dim(key_proj)
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    breakpoint()
-
+    attn_output = attn_output.astype(dtype=dtype)

+    hidden_states = self.proj_attn(attn_output)
+    return hidden_states


-  def setup(self):
-    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
-      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
-    inner_dim = self.dim_head * self.heads
-    scale = self.dim_head**-0.5
-
-    self.attention_op = NNXAttentionOp(
-        mesh=self.mesh,
-        attention_kernel=self.attention_kernel,
-        scale=scale,
-        heads=self.heads,
-        dim_head=self.dim_head,
-        flash_min_seq_length=self.flash_min_seq_length,
-        use_memory_efficient_attention=self.use_memory_efficient_attention,
-        split_head_dim=self.split_head_dim,
-        flash_block_sizes=self.flash_block_sizes,
-        dtype=self.dtype,
-        float32_qk_product=False,
-    )
-
 class FlaxFluxAttention(nn.Module):
   query_dim: int
   heads: int = 8
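
One detail worth calling out in the `__call__` change above: attention backends may return activations in a wider dtype than the inputs, for example when the softmax or QK product is computed in float32, so the output is cast back to the dtype captured at the top of the call. A small sketch of that round trip, with placeholder shapes:

```python
import jax.numpy as jnp

# bfloat16 activations going in; a float32 stand-in for apply_attention(...) coming back.
hidden_states = jnp.ones((1, 128, 512), dtype=jnp.bfloat16)
dtype = hidden_states.dtype

attn_output = jnp.ones((1, 128, 512), dtype=jnp.float32)
attn_output = attn_output.astype(dtype)
print(attn_output.dtype)  # bfloat16
```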