@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    residual_checkpoint_name: str | None = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -213,9 +214,22 @@ def _tpu_flash_attention(
   )
   def wrap_flash_attention(query, key, value):
 
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
+    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
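Note on the hunk above: each operand's sequence axis is now padded to a multiple of the largest block size the chosen backward kernel may use, so one padding satisfies both the forward and backward Pallas grids. A minimal sketch of that padding step, assuming a [batch, heads, seq, head_dim] layout and a hypothetical helper name (this is not the repository's _pad_data_for_flash):

import jax.numpy as jnp

def pad_seq_to_block(x, block_size, seq_axis=2):
  # Pad the sequence axis up to the next multiple of block_size and return
  # the original length so padding can be sliced off after the kernel runs.
  seq_len = x.shape[seq_axis]
  padded_len = -(-seq_len // block_size) * block_size  # ceiling to a multiple
  pad_width = [(0, 0)] * x.ndim
  pad_width[seq_axis] = (0, padded_len - seq_len)
  return jnp.pad(x, pad_width), seq_len

# Mirroring the change: one block size per operand, taken as the max over
# every block size the forward and backward passes can request, e.g.
# block_q = max(*block_q_sizes); query, query_seq_len = pad_seq_to_block(query, block_q)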
@@ -237,6 +251,7 @@ def wrap_flash_attention(query, key, value):
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
         save_residuals=True if attention_kernel == "ring" else False,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    residual_checkpoint_name: str | None = None,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.residual_checkpoint_name = residual_checkpoint_name
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        residual_checkpoint_name=self.residual_checkpoint_name,
     )
 
 
@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
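Overall, residual_checkpoint_name is simply threaded from the module constructors down to the splash-attention kernel construction shown in wrap_flash_attention. A hedged sketch of how a caller might pair it with a standard JAX rematerialization policy so that residuals tagged under that name are kept rather than recomputed (the name string and the decoder_layer wrapper are assumptions for illustration, not code from this change):

import functools
import jax

RESIDUAL_NAME = "attn_residuals"  # assumed; must match residual_checkpoint_name passed to the attention op

# Keep only tensors checkpointed under RESIDUAL_NAME across the remat
# boundary; every other activation is recomputed in the backward pass.
policy = jax.checkpoint_policies.save_only_these_names(RESIDUAL_NAME)

@functools.partial(jax.checkpoint, policy=policy)
def decoder_layer(params, x):
  # hypothetical layer body whose attention op was constructed with
  # residual_checkpoint_name=RESIDUAL_NAME
  ...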