File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -53,7 +53,8 @@ def extract_logits(path, val):
5353 if not all_max_logits :
5454 return None
5555
56- return jnp .max (jnp .stack (all_max_logits ))
56+ # Compute max per layer first to handle potential shape mismatches
57+ return jnp .max (jnp .stack ([jnp .max (x ) for x in all_max_logits ]))
5758
5859
5960def apply_qk_clip (state , intermediate_outputs , config ):
Original file line number Diff line number Diff line change @@ -839,3 +839,31 @@ def test_engram_integration(self):
839839 "hf_access_token=fake" ,
840840 )
841841 )
842+
843+ @pytest .mark .cpu_only
844+ def test_qk_clip (self ):
845+ """AOT test for qk-clip with DeepSeek3 Tiny model"""
846+ compiled_trainstep_file = "/tmp/test_qk_clip.pickle"
847+ train_compile_main (
848+ (
849+ "" ,
850+ get_test_config_path (),
851+ f"compiled_trainstep_file={ compiled_trainstep_file } " ,
852+ "compile_topology=v5p-8" ,
853+ "compile_topology_num_slices=1" ,
854+ "model_name=deepseek3-tiny" ,
855+ "scan_layers=True" ,
856+ "sparse_matmul=True" ,
857+ "megablox=True" ,
858+ "use_tokamax_gmm=False" ,
859+ # TODO(agagik): update to flash after support
860+ "attention=dot_product" ,
861+ "use_tokamax_splash=True" ,
862+ "max_target_length=128" ,
863+ "per_device_batch_size=1" ,
864+ "dtype=bfloat16" ,
865+ "weight_dtype=float32" ,
866+ "use_qk_clip=true" ,
867+ "qk_clip_threshold=100" ,
868+ )
869+ )
You can’t perform that action at this time.
0 commit comments