Skip to content

Commit a3c19fd

Browse files
Merge pull request #3420 from AI-Hypercomputer:amandaliang
PiperOrigin-RevId: 884673900
2 parents 5cd1acb + dfe361d commit a3c19fd

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -804,16 +804,11 @@ def gmm(
804804
input_buffer_count,
805805
combine_scopes,
806806
):
807-
808-
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
809-
group_sizes,
810-
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
811-
)
812807
if config.use_qwix_quantization:
813808
output = megablox.gmm(
814809
lhs=inputs,
815810
rhs=kernel,
816-
group_sizes=tokamax_group_sizes,
811+
group_sizes=group_sizes,
817812
preferred_element_type=preferred_element_type,
818813
tiling=tiling,
819814
use_qwix_quantization=config.use_qwix_quantization,
@@ -827,7 +822,10 @@ def gmm(
827822
output = tokamax.ragged_dot(
828823
lhs=inputs,
829824
rhs=kernel,
830-
group_sizes=tokamax_group_sizes,
825+
group_sizes=tokamax.RaggedDotGroupSizes(
826+
group_sizes,
827+
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
828+
),
831829
precision=jax.lax.Precision.DEFAULT,
832830
preferred_element_type=preferred_element_type,
833831
implementation="mosaic",

0 commit comments

Comments
 (0)