@@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config(
272272 attn_logits_soft_cap : float | None = None ,
273273 fuse_reciprocal : bool = True ,
274274 use_base2_exp : bool = False ,
275+ use_experimental_scheduler : bool = False ,
275276 max_logit_const : float | None = None ,
276277 interpret : bool = False ,
277278 dq_reduction_steps : int | None = None ,
@@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config(
294295 attn_logits_soft_cap = attn_logits_soft_cap ,
295296 fuse_reciprocal = fuse_reciprocal ,
296297 use_base2_exp = use_base2_exp ,
298+ use_experimental_scheduler = use_experimental_scheduler ,
297299 max_logit_const = max_logit_const ,
298300 interpret = interpret ,
299301 dq_reduction_steps = dq_reduction_steps ,
@@ -314,6 +316,8 @@ def _tpu_flash_attention(
314316 mask_padding_tokens : bool = True ,
315317 residual_checkpoint_name : str | None = None ,
316318 attention_mask : jax .Array = None ,
319+ use_base2_exp : bool = False ,
320+ use_experimental_scheduler : bool = False ,
317321) -> jax .Array :
318322 """TPU Flash Attention"""
319323
@@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value):
399403 splash_kernel = tokamax_splash_attention_kernel .make_splash_mha (
400404 mask = mask ,
401405 q_seq_shards = 1 , # the size of the axis that seq_len is sharded over
402- config = convert_to_tokamax_splash_config (block_sizes , residual_checkpoint_name = residual_checkpoint_name ),
406+ config = convert_to_tokamax_splash_config (
407+ block_sizes ,
408+ residual_checkpoint_name = residual_checkpoint_name ,
409+ use_base2_exp = use_base2_exp ,
410+ use_experimental_scheduler = use_experimental_scheduler ,
411+ ),
403412 save_residuals = False ,
404413 )
405414 elif attention_kernel == "tokamax_ring" :
@@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value):
409418 splash_kernel = tokamax_ring_attention_kernel .make_ring_attention (
410419 mask = mask ,
411420 is_mqa = False ,
412- config = convert_to_tokamax_splash_config (block_sizes , residual_checkpoint_name = residual_checkpoint_name ),
421+ config = convert_to_tokamax_splash_config (
422+ block_sizes ,
423+ residual_checkpoint_name = residual_checkpoint_name ,
424+ use_base2_exp = use_base2_exp ,
425+ use_experimental_scheduler = use_experimental_scheduler ,
426+ ),
413427 save_residuals = False ,
414428 ring_axis = "context" ,
415429 rotate_segment_ids = False , # We don't rotate segment ids in tokamax ring attention because our segment ids are used for padding, so each kv shard has the same segment ids
@@ -741,6 +755,8 @@ def _apply_attention(
741755 mask_padding_tokens : bool = True ,
742756 residual_checkpoint_name : str | None = None ,
743757 attention_mask : Array = None ,
758+ use_base2_exp : bool = False ,
759+ use_experimental_scheduler : bool = False ,
744760):
745761 """Routes to different attention kernels."""
746762 _check_attention_inputs (query , key , value )
@@ -789,6 +805,8 @@ def _apply_attention(
789805 mask_padding_tokens = mask_padding_tokens ,
790806 residual_checkpoint_name = residual_checkpoint_name ,
791807 attention_mask = attention_mask ,
808+ use_base2_exp = use_base2_exp ,
809+ use_experimental_scheduler = use_experimental_scheduler ,
792810 )
793811 elif "ring" in attention_kernel :
794812 return _tpu_flash_attention (
@@ -983,8 +1001,12 @@ def __init__(
9831001 quant : Quant = None ,
9841002 mask_padding_tokens : bool = True ,
9851003 residual_checkpoint_name : str | None = None ,
1004+ use_base2_exp : bool = False ,
1005+ use_experimental_scheduler : bool = False ,
9861006 ):
9871007 self .dpa_layer = None
1008+ self .use_base2_exp = use_base2_exp
1009+ self .use_experimental_scheduler = use_experimental_scheduler
9881010 if attention_kernel == "cudnn_flash_te" :
9891011 from transformer_engine .jax .flax .transformer import DotProductAttention # pytype: disable=import-error
9901012
@@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
10451067 mask_padding_tokens = self .mask_padding_tokens ,
10461068 residual_checkpoint_name = self .residual_checkpoint_name ,
10471069 attention_mask = attention_mask ,
1070+ use_base2_exp = self .use_base2_exp if hasattr (self , "use_base2_exp" ) else False ,
1071+ use_experimental_scheduler = self .use_experimental_scheduler if hasattr (self , "use_experimental_scheduler" ) else False ,
10481072 )
10491073
10501074
@@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module):
10631087 flash_block_sizes : BlockSizes = None
10641088 dtype : DType = jnp .float32
10651089 quant : Quant = None
1090+ use_base2_exp : bool = False
1091+ use_experimental_scheduler : bool = False
10661092
10671093 def setup (self ):
10681094 self .dpa_layer = None
@@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
11081134 flash_block_sizes = self .flash_block_sizes ,
11091135 dpa_layer = self .dpa_layer ,
11101136 attention_mask = attention_mask ,
1137+ use_base2_exp = self .use_base2_exp ,
1138+ use_experimental_scheduler = self .use_experimental_scheduler ,
11111139 )
11121140
11131141
@@ -1144,6 +1172,8 @@ def __init__(
11441172 enable_jax_named_scopes : bool = False ,
11451173 added_kv_proj_dim : Optional [int ] = None , # New for I2V
11461174 image_seq_len : Optional [int ] = None , # New for I2V
1175+ use_base2_exp : bool = False ,
1176+ use_experimental_scheduler : bool = False ,
11471177 ):
11481178 if attention_kernel in {"flash" , "cudnn_flash_te" } and mesh is None :
11491179 raise ValueError (f"The flash attention kernel requires a value for mesh, but mesh is { self .mesh } " )
@@ -1186,6 +1216,8 @@ def __init__(
11861216 quant = quant ,
11871217 mask_padding_tokens = mask_padding_tokens ,
11881218 residual_checkpoint_name = residual_checkpoint_name ,
1219+ use_base2_exp = use_base2_exp ,
1220+ use_experimental_scheduler = use_experimental_scheduler ,
11891221 )
11901222 # None axes corresponds to the stacked weights across all blocks
11911223 # because of the use of nnx.vmap and nnx.scan.
0 commit comments