File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -804,16 +804,11 @@ def gmm(
804804 input_buffer_count ,
805805 combine_scopes ,
806806 ):
807-
808- tokamax_group_sizes = tokamax .RaggedDotGroupSizes (
809- group_sizes ,
810- max_utils .generate_representative_group_sizes (inputs .shape [0 ], kernel .shape [0 ]),
811- )
812807 if config .use_qwix_quantization :
813808 output = megablox .gmm (
814809 lhs = inputs ,
815810 rhs = kernel ,
816- group_sizes = tokamax_group_sizes ,
811+ group_sizes = group_sizes ,
817812 preferred_element_type = preferred_element_type ,
818813 tiling = tiling ,
819814 use_qwix_quantization = config .use_qwix_quantization ,
@@ -827,7 +822,10 @@ def gmm(
827822 output = tokamax .ragged_dot (
828823 lhs = inputs ,
829824 rhs = kernel ,
830- group_sizes = tokamax_group_sizes ,
825+ group_sizes = tokamax .RaggedDotGroupSizes (
826+ group_sizes ,
827+ max_utils .generate_representative_group_sizes (inputs .shape [0 ], kernel .shape [0 ]),
828+ ),
831829 precision = jax .lax .Precision .DEFAULT ,
832830 preferred_element_type = preferred_element_type ,
833831 implementation = "mosaic" ,
You can’t perform that action at this time.
0 commit comments