@@ -42,6 +42,7 @@ def gmm(
4242 use_qwix_quantization : bool = False ,
4343 use_tokamax_backend : bool = False ,
4444 weight_gather_axes : List [Tuple [str , int ]] | None = None ,
45+ input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
4546):
4647 """Grouped matrix multiplication operation."""
4748 quantization_rule = None
@@ -61,14 +62,15 @@ def gmm(
6162 )
6263
6364 gmm_fwd_bwd = lambda * args : _gmm_fwd (* args )[0 ] # pylint: disable=C3001
64- gmm_fwd_bwd = jax .custom_vjp (gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 7 , 8 , 9 , 10 , 11 ))
65+ gmm_fwd_bwd = jax .custom_vjp (gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 5 , 8 , 9 , 10 , 11 , 12 ))
6566 gmm_fwd_bwd .defvjp (_gmm_fwd , functools .partial (_gmm_bwd , lhs .dtype , rhs .dtype ))
6667 return gmm_fwd_bwd (
6768 lhs ,
6869 rhs ,
6970 group_sizes ,
7071 preferred_element_type ,
7172 tiling ,
73+ input_buffer_count ,
7274 group_offset ,
7375 existing_out ,
7476 transpose_rhs ,
@@ -85,6 +87,7 @@ def _gmm_fwd(
8587 group_sizes : jnp .ndarray ,
8688 preferred_element_type : jnp .dtype = jnp .float32 ,
8789 tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
90+ input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
8891 group_offset : jnp .ndarray | None = None ,
8992 existing_out : jnp .ndarray | None = None ,
9093 transpose_rhs : bool = False ,
@@ -125,9 +128,7 @@ def _gmm_fwd(
125128 # QAG is only supported for following conditions
126129 if use_tokamax_backend :
127130 if quantization_rule and quantization_rule .bwd_qtype :
128- if quantization_rule .weight_calibration_method .startswith (
129- "fixed"
130- ) and isinstance (rhs , qpl .QArray ):
131+ if quantization_rule .weight_calibration_method .startswith ("fixed" ) and isinstance (rhs , qpl .QArray ):
131132 if weight_gather_axes :
132133 for axis_name , axis_idx in weight_gather_axes :
133134 rhs_qvalue = jax .lax .all_gather (rhs .qvalue , axis_name , axis = axis_idx , tiled = True )
@@ -142,6 +143,7 @@ def _gmm_fwd(
142143 group_offset = group_offset ,
143144 transpose_rhs = transpose_rhs ,
144145 interpret = interpret ,
146+ input_buffer_count = input_buffer_count [0 ],
145147 )
146148 else :
147149 out = backend .gmm (
@@ -163,6 +165,7 @@ def _gmm_bwd(
163165 rhs_dtype : jax .typing .DTypeLike ,
164166 preferred_element_type : jnp .dtype ,
165167 tiling : tuple [int , int , int , int , int , int , int , int , int ],
168+ input_buffer_count : tuple [int , int , int ],
166169 transpose_rhs : bool ,
167170 interpret : bool ,
168171 quantization_rule : qwix .QtRule | None ,
@@ -229,6 +232,7 @@ def _gmm_bwd(
229232 group_offset = group_offset ,
230233 transpose_rhs = not transpose_rhs ,
231234 interpret = interpret ,
235+ input_buffer_count = input_buffer_count [1 ],
232236 )
233237 drhs = tokamax_backend .tgmm (
234238 lhs = lhs .swapaxes (0 , 1 ),
@@ -240,6 +244,7 @@ def _gmm_bwd(
240244 group_offset = group_offset ,
241245 num_actual_groups = num_actual_groups ,
242246 interpret = interpret ,
247+ input_buffer_count = input_buffer_count [2 ],
243248 )
244249 if quantization_rule and quantization_rule .bwd_qtype and weight_gather_axes :
245250 # Scatter back in reverse order of gather
0 commit comments