Skip to content

Commit 0a39e70

Browse files
Merge pull request #3041 from Steboss:main
PiperOrigin-RevId: 864463680
2 parents 8cc3ba7 + 7f13f98 commit 0a39e70

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/MaxText/layers/moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,9 +950,13 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
950950
# Use full contraction for QWIX quantization to allow quantization
951951
# fusion (max reduce over contracting dimension).
952952
tiling = (tiling[0], k, tiling[2])
953+
954+
is_tpu = (self.mesh.devices.flat[0] == "tpu")
955+
# TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
956+
mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else '0'
953957
with set_xla_metadata(
954958
ragged_dot_tiling=",".join([str(t) for t in tiling]),
955-
mosaic_fusion_group=f"{random.randint(0, 1000000000)}",
959+
mosaic_fusion_group=mosaic_group_id,
956960
):
957961
output = jax.lax.ragged_dot(
958962
lhs=inputs,

0 commit comments

Comments
 (0)