Skip to content

Commit 8d4f13a

Browse files
Merge pull request #3397 from AI-Hypercomputer:chengnuojin-enable-pipeline-2dfsdp
PiperOrigin-RevId: 883239276
2 parents 00ef5de + 4bc91e0 commit 8d4f13a

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -896,10 +896,14 @@ def sparse_matmul(
896896
def gmm(
897897
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
898898
):
899-
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
900-
group_sizes,
901-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
902-
)
899+
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
900+
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
901+
tokamax_group_sizes = group_sizes
902+
else:
903+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
904+
group_sizes,
905+
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
906+
)
903907
pad_length = self.config.wi_tile_fwd_batch_seq
904908
hs_shape = inputs.shape
905909
# pad length is the 1st dimension of tiling size in gmm call

0 commit comments

Comments
 (0)