@@ -82,6 +82,7 @@ class Config:
   qk_nope_head_dim: int = 128
   qk_rope_head_dim: int = 64
   v_head_dim: int = 128
+  use_tokamax_splash: bool = True
   # yarn
   rope_type: str = "yarn"
   original_max_position_embeddings: int = 4096
@@ -98,7 +99,6 @@ class Config:
   use_sparse_indexer: bool = True
   index_n_heads: int = 64
   index_head_dim: int = 128  # > qk_rope_head_dim
-  index_topk: int = 4
 
 
 class ModelArgs:
@@ -107,7 +107,7 @@ class ModelArgs:
   Maps MaxText Config keys to the specific variable names expected by the reference implementation.
   """
 
-  def __init__(self, config: Config, max_batch_size: int = 8):
+  def __init__(self, config: Config, max_batch_size: int = 8, index_topk: int = 4):
     self.max_batch_size = max_batch_size
     self.scale_fmt = None
     self.max_seq_len = config.max_position_embeddings
@@ -119,6 +119,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
     self.qk_nope_head_dim = config.qk_nope_head_dim
     self.qk_rope_head_dim = config.qk_rope_head_dim
     self.v_head_dim = config.v_head_dim
+    self.use_tokamax_splash = config.use_tokamax_splash
     # yarn
     self.original_seq_len = config.original_max_position_embeddings
     self.rope_theta = float(config.rope_max_timescale)
@@ -129,7 +130,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
     # indexer
     self.index_n_heads = config.index_n_heads
     self.index_head_dim = config.index_head_dim
-    self.index_topk = config.index_topk
+    self.index_topk = index_topk
 
 
 # -----------------------------------------------------------------------------
@@ -457,14 +458,14 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
 
 class Indexer(torch.nn.Module):  # pylint: disable=missing-class-docstring
 
-  def __init__(self, args: ModelArgs):
+  def __init__(self, args: ModelArgs, index_topk: int = 4):
     super().__init__()
     self.dim: int = args.dim
     self.n_heads: int = args.index_n_heads
     self.n_local_heads = args.index_n_heads // world_size
     self.head_dim: int = args.index_head_dim
     self.rope_head_dim: int = args.qk_rope_head_dim
-    self.index_topk: int = args.index_topk
+    self.index_topk: int = index_topk
     self.q_lora_rank: int = args.q_lora_rank
     self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
     self.wk = Linear(self.dim, self.head_dim)
@@ -580,7 +581,7 @@ class MLA(nn.Module):
     softmax_scale (float): Scaling factor for softmax in attention computation.
   """
 
-  def __init__(self, args: ModelArgs):
+  def __init__(self, args: ModelArgs, index_topk: int):
     super().__init__()
     self.dim = args.dim
     self.n_heads = args.n_heads
@@ -605,7 +606,7 @@ def __init__(self, args: ModelArgs):
       mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
-    self.indexer = Indexer(args)
+    self.indexer = Indexer(args, index_topk)
 
     self.register_buffer(
         "kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False
@@ -750,7 +751,7 @@ def get_jax_mla_weights(pt_mla, cfg):
   }
 
 
-def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
+def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len, attention, index_topk):
   """Returns MaxText configuration and mesh."""
   cfg = pyconfig.initialize(
       [None, get_test_config_path()],
@@ -766,7 +767,8 @@ def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
       per_device_batch_size=batch_size,
       max_target_length=seq_len,
       max_prefill_predict_length=seq_len,
-      attention="dot_product",
+      attention=attention,
+      index_topk=index_topk,
       **asdict(config),
   )
   devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -785,7 +787,7 @@ def setUp(self):
     np.random.seed(42)
 
     self.dtype = "float32"
-    self.batch_size = 2
+    self.batch_size = 4
     self.start_pos = 0
     self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42))
     # jax config
@@ -861,6 +863,8 @@ def test_indexer_match(self, seq_len=8):
         dtype=self.dtype,
         batch_size=self.batch_size,
         seq_len=self.seq_len,
+        attention="dot_product",
+        index_topk=4,
     )
 
     # Indexer specific RoPE (interleave=False)
@@ -906,17 +910,53 @@ class DeepseekV32MLATest(DeepseekTestBase):
   """Tests for MLA Attention with Sparse Indexing."""
 
   @parameterized.named_parameters(
-      {"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
-      {"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
+      {
+          "testcase_name": "dot_product_s2_k4",
+          "attention": "dot_product",
+          "seq_len": 2,
+          "index_topk": 4,
+      },
+      {
+          "testcase_name": "dot_product_s8_k4",
+          "attention": "dot_product",
+          "seq_len": 8,
+          "index_topk": 4,
+      },
+      {
+          "testcase_name": "dot_product_s128_k4",
+          "attention": "dot_product",
+          "seq_len": 128,
+          "index_topk": 4,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "dot_product_s128_k128",
+          "attention": "dot_product",
+          "seq_len": 128,
+          "index_topk": 128,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "flash_s128_k4",
+          "attention": "flash",
+          "seq_len": 128,
+          "index_topk": 4,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "flash_s128_k128",
+          "attention": "flash",
+          "seq_len": 128,
+          "index_topk": 128,
+          "check_norm": True,
+      },
   )
-  # index_topk=4
-  def test_mla_match(self, seq_len=8):
-    """Verifies MLA output (train mode) matches PyTorch (MHA mode) with indexer."""
-
+  def test_mla_parity(self, attention, seq_len, index_topk, check_norm=False):
+    """Verifies JAX MLA output against the PyTorch reference implementation."""
     torch_inputs, jax_inputs = self.get_data(seq_len)
 
     # 1. PyTorch Run
-    pt_mla = MLA(self.pt_args)
+    pt_mla = MLA(self.pt_args, index_topk)
     init_torch_weights(pt_mla)
     pt_mla.eval()
 
@@ -936,6 +976,8 @@ def test_mla_match(self, seq_len=8):
         dtype=self.dtype,
         batch_size=self.batch_size,
         seq_len=self.seq_len,
+        attention=attention,
+        index_topk=index_topk,
     )
 
     jax_mla = attention_mla.MLA(
@@ -959,7 +1001,7 @@ def test_mla_match(self, seq_len=8):
         rope_factor=cfg.rope_factor,
         max_target_length=self.seq_len,
         mesh=mesh,
-        attention_kernel="dot_product",
+        attention_kernel=attention,
         inputs_q_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
         inputs_kv_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
         rngs=self.nnx_rng,
@@ -976,10 +1018,17 @@ def test_mla_match(self, seq_len=8):
         model_mode=MODEL_MODE_TRAIN,
     )
 
-    # 3 Compare
-    print("torch out", pt_out)
-    print("jax out", jax_out)
-    np.testing.assert_allclose(to_jax(pt_out), jax_out, rtol=1e-2, atol=1e-2)
+    # 3. Compare
+    if check_norm:
+      expected = to_jax(pt_out) / jnp.linalg.norm(to_jax(pt_out))
+      actual = jax_out / jnp.linalg.norm(jax_out)
+    else:
+      expected = to_jax(pt_out)
+      actual = jax_out
+
+    print("torch out", expected)
+    print("jax out", actual)
+    np.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2)
 
 
 if __name__ == "__main__":
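
For the longer-sequence cases, the new `check_norm` branch compares direction rather than raw magnitude: both outputs are rescaled to unit L2 norm before `assert_allclose`, so a small global scale drift between the PyTorch and JAX backends does not dominate the element-wise tolerance. A minimal standalone sketch of that check (NumPy only; the helper name and the random test data are illustrative, not part of this change):

```python
import numpy as np


def assert_direction_close(expected, actual, rtol=1e-2, atol=1e-2):
  """Compare unit-L2-normalized tensors, mirroring the check_norm branch above."""
  expected = np.asarray(expected, dtype=np.float32)
  actual = np.asarray(actual, dtype=np.float32)
  np.testing.assert_allclose(
      expected / np.linalg.norm(expected),  # norm over the flattened array
      actual / np.linalg.norm(actual),
      rtol=rtol,
      atol=atol,
  )


# Outputs that differ only by a global scale factor still pass the check.
x = np.random.default_rng(0).normal(size=(4, 128, 16)).astype(np.float32)
assert_direction_close(x, 1.001 * x)
```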