Skip to content

Commit 73b295e

Browse files
committed
Skip MHC gmm tests for GPU
1 parent 9a6ff81 commit 73b295e

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

tests/unit/mhc_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import os.path
1818
import unittest
19+
import pytest
1920

2021
from flax import nnx
2122
from flax.linen import partitioning as nn_partitioning
@@ -118,6 +119,8 @@ def setUp(self):
118119
),
119120
)
120121

122+
# Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
123+
@pytest.mark.tpu_only
121124
def test_moe_layer_output_shape(self):
122125
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
123126
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)

0 commit comments

Comments
 (0)