Commit d5805b2

Skip Tokamax RaggedDotGroupSizes for FP8
## Description

The FP8 path still uses Tokamax internal backend APIs. The new `RaggedDotGroupSizes` type was introduced ([pull3330](#3330)) for the Tokamax public APIs on the bf16 path, which broke FP8.

## Tests

Benchmarks were run internally.

## Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [X] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [X] I have necessary comments in my code, particularly in hard-to-understand areas.
- [X] I have run end-to-end tests and provided workload links above if applicable.
- [X] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).
1 parent 5d9e57f commit d5805b2

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxtext/layers/moe.py

```diff
@@ -897,7 +897,7 @@ def gmm(
       inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
   ):
     # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
-    if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
+    if self.config.use_qwix_quantization or (self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat):
       tokamax_group_sizes = group_sizes
     else:
       tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
```
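The guard added above can be sketched in isolation. This is a minimal, hypothetical model of the branch (the `Config` dataclass, `select_group_sizes` helper, and `wrap` callable are illustrative stand-ins, not MaxText or Tokamax APIs): when the FP8/qwix path or the pipeline-FSDP-per-repeat path is active, the raw `group_sizes` array is passed through unchanged; only otherwise is it wrapped in the public-API group-sizes type.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical stand-in for the relevant MaxText config flags."""
    use_qwix_quantization: bool = False
    using_pipeline_parallelism: bool = False
    pipeline_fsdp_ag_per_repeat: bool = False


def select_group_sizes(config, group_sizes, wrap):
    """Return raw group_sizes on the FP8 (qwix) or pipeline-FSDP paths,
    otherwise wrap it in the public-API form (e.g. RaggedDotGroupSizes)."""
    if config.use_qwix_quantization or (
        config.using_pipeline_parallelism and config.pipeline_fsdp_ag_per_repeat
    ):
        # FP8 still goes through internal backend APIs, which expect the
        # plain array; wrapping it here is what broke the FP8 path.
        return group_sizes
    return wrap(group_sizes)
```

The key design choice is that the new condition short-circuits before wrapping, so the FP8 path never sees the public-API type it cannot handle, while the bf16 path keeps the `RaggedDotGroupSizes` wrapper introduced in pull3330.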
