Skip to content

Commit 23a82de

Browse files
Merge pull request #3277 from AI-Hypercomputer:agagik-qk-logits
PiperOrigin-RevId: 877549235
2 parents 305e4e0 + a92f703 commit 23a82de

2 files changed

Lines changed: 30 additions & 1 deletion

File tree

src/maxtext/utils/qk_clip_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@ def extract_logits(path, val):
53 53
if not all_max_logits:
54 54
return None
55 55

56-
return jnp.max(jnp.stack(all_max_logits))
56+
# Compute max per layer first to handle potential shape mismatches
57+
return jnp.max(jnp.stack([jnp.max(x) for x in all_max_logits]))
57 58

58 59

59 60
def apply_qk_clip(state, intermediate_outputs, config):

tests/unit/train_compile_test.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -839,3 +839,31 @@ def test_engram_integration(self):
839 839
"hf_access_token=fake",
840 840
)
841 841
)
842+
843+
@pytest.mark.cpu_only
844+
def test_qk_clip(self):
845+
"""AOT test for qk-clip with DeepSeek3 Tiny model"""
846+
compiled_trainstep_file = "/tmp/test_qk_clip.pickle"
847+
train_compile_main(
848+
(
849+
"",
850+
get_test_config_path(),
851+
f"compiled_trainstep_file={compiled_trainstep_file}",
852+
"compile_topology=v5p-8",
853+
"compile_topology_num_slices=1",
854+
"model_name=deepseek3-tiny",
855+
"scan_layers=True",
856+
"sparse_matmul=True",
857+
"megablox=True",
858+
"use_tokamax_gmm=False",
859+
# TODO(agagik): update to flash after support
860+
"attention=dot_product",
861+
"use_tokamax_splash=True",
862+
"max_target_length=128",
863+
"per_device_batch_size=1",
864+
"dtype=bfloat16",
865+
"weight_dtype=float32",
866+
"use_qk_clip=true",
867+
"qk_clip_threshold=100",
868+
)
869+
)

0 commit comments

Comments
 (0)