@@ -181,6 +181,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    mask_padding_tokens: bool = True,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -248,6 +249,8 @@ def wrap_flash_attention(query, key, value):
     )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
+    if not mask_padding_tokens:
+      segment_ids = None
     if attention_kernel == "flash":
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
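
Note on what the new flag toggles: when mask_padding_tokens is False the splash kernel above is invoked with segment_ids=None, so keys and values at padded positions are no longer excluded from the softmax. A minimal, self-contained sketch of that difference using plain softmax attention (illustrative only; the code in this diff uses the Pallas splash kernel, not this toy function):

import jax
import jax.numpy as jnp

def toy_attention(q, k, v, padding_mask=None):
  # q, k, v: [seq, dim]; padding_mask: [seq], 1 for real tokens, 0 for padding.
  logits = q @ k.T / jnp.sqrt(q.shape[-1])
  if padding_mask is not None:
    logits = jnp.where(padding_mask[None, :] == 1, logits, -1e30)
  return jax.nn.softmax(logits, axis=-1) @ v

q = k = jnp.ones((4, 8))
v = jnp.arange(4, dtype=jnp.float32)[:, None] * jnp.ones((4, 8))
mask = jnp.array([1, 1, 1, 0])            # last position is padding
masked = toy_attention(q, k, v, mask)     # padding key/value contributes nothing
unmasked = toy_attention(q, k, v, None)   # padding key/value is averaged in
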
@@ -287,6 +290,8 @@ def ring_scan_body(carry, _):
       (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
 
       attention_output = o_final / l_final[..., None]
+    else:
+      raise ValueError("ring attention requires fsdp > 1")
 
     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
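
Why the new ValueError: the ring pass above runs jax.lax.scan for num_fsdp_shards - 1 steps, so with a single FSDP shard there are no ring steps to run and nothing is exchanged between shards; the new branch turns that case into an explicit error. A tiny sketch of the underlying scan behavior (a zero-length scan simply returns its initial carry):

import jax
import jax.numpy as jnp

def ring_step(carry, _):
  return carry + 1, carry  # stand-in for one hop around the ring

final, hops = jax.lax.scan(ring_step, jnp.array(0), None, length=0)
# final == 0 (carry untouched), hops.shape == (0,): no ring steps happened
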
@@ -427,6 +432,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    mask_padding_tokens: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -457,10 +463,12 @@ def _apply_attention(
         flash_block_sizes,
         dtype,
         attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -591,6 +599,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      mask_padding_tokens: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -610,6 +619,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.mask_padding_tokens = mask_padding_tokens
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -630,6 +640,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        mask_padding_tokens=self.mask_padding_tokens,
     )
 
 
@@ -719,6 +730,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
+      mask_padding_tokens: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -757,6 +769,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        mask_padding_tokens=mask_padding_tokens,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
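
The diff does not show how segment_ids are produced; they are typically derived from each example's count of real tokens so the kernel can tell real positions from padding. A generic construction sketch (the helper name and shapes below are assumptions for illustration, not this repository's code):

import jax.numpy as jnp

def make_padding_segment_ids(lengths, seq_len):
  # lengths: [batch] number of real tokens per example.
  # Returns [batch, seq_len]: 1 for real tokens, 0 for padding.
  positions = jnp.arange(seq_len)[None, :]
  return (positions < lengths[:, None]).astype(jnp.int32)

ids = make_padding_segment_ids(jnp.array([5, 3]), seq_len=8)
# With mask_padding_tokens=True, ids like these (for both q and kv) drive the kernel's
# segment-id masking; with False, the kernel simply receives segment_ids=None.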