Skip to content

Commit 5bd542c

Browse files
authored
Update moe.py
1 parent 34ac013 commit 5bd542c

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -897,8 +897,7 @@ def gmm(
897897
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
898898
):
899899
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
900-
if self.config.use_qwix_quantization or (self.config.using_pipeline_parallelism and
901-
self.config.pipeline_fsdp_ag_per_repeat):
900+
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
902901
tokamax_group_sizes = group_sizes
903902
else:
904903
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
@@ -935,7 +934,7 @@ def gmm(
935934
output = mblx.gmm(
936935
lhs=inputs,
937936
rhs=kernel,
938-
group_sizes=tokamax_group_sizes,
937+
group_sizes=group_sizes,
939938
preferred_element_type=self.dtype,
940939
tiling=tiling,
941940
lhs_quantize_dtype=lhs_quantize_dtype,

0 commit comments

Comments
 (0)