Skip to content

Commit c7355aa

Browse files
Merge pull request #3404 from ROCm:mla-test-ici-parallelism
PiperOrigin-RevId: 886875950
2 parents 4cf5bee + 2329d7e commit c7355aa

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

tests/unit/attention_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,7 @@ def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len):
16111611
def test_indexer_loss(self):
16121612
"""Test indexer loss computation."""
16131613
mla_config_args = self.config_arguments.copy()
1614+
mla_config_args.update(get_decoupled_parallelism_overrides())
16141615
mla_config_args["use_sparse_indexer"] = True
16151616
mla_config_args["attention"] = "dot_product"
16161617
_, mla = self.init_mla(mla_config_args, rope_type="default")
@@ -1657,6 +1658,7 @@ def test_indexer_loss(self):
16571658
def test_indexer_loss_kl_divergence_zero(self):
16581659
"""Test that KL divergence is 0 when target and pred distributions match exactly."""
16591660
mla_config_args = self.config_arguments.copy()
1661+
mla_config_args.update(get_decoupled_parallelism_overrides())
16601662
mla_config_args["use_sparse_indexer"] = True
16611663
mla_config_args["attention"] = "dot_product"
16621664
_, mla = self.init_mla(mla_config_args, rope_type="default")

0 commit comments

Comments
 (0)