Skip to content

Commit 47cd13a

Browse files
authored
Update moe.py
1 parent d5805b2 commit 47cd13a

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxtext/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,8 @@ def gmm(
897897
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
898898
):
899899
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
900-
if self.config.use_qwix_quantization or (self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat):
900+
if self.config.use_qwix_quantization or (self.config.using_pipeline_parallelism and
901+
self.config.pipeline_fsdp_ag_per_repeat):
901902
tokamax_group_sizes = group_sizes
902903
else:
903904
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(

0 commit comments

Comments
 (0)