@@ -406,6 +406,7 @@ def batch_split_schedule(
406406 rope_factor = cfg .rope_factor ,
407407 mscale = cfg .mscale ,
408408 dtype = cfg .dtype ,
409+ quant = quant ,
409410 )
410411
411412 xs = moe (
@@ -418,6 +419,7 @@ def batch_split_schedule(
418419 expert_axis_name = "expert" ,
419420 use_gather_mosaic_kernel = False ,
420421 config = cfg ,
422+ quant = quant ,
421423 )
422424 return xs
423425
@@ -440,7 +442,21 @@ def with_data_parallel_constraint(x, mesh):
440442 return jax .lax .with_sharding_constraint (x , jax .NamedSharding (mesh , activation_pspec ))
441443
442444
def dot(x, y, quant=None, axes=1):
    """Computes a tensordot of two arrays, optionally using quantization.

    Args:
      x: Left operand array.
      y: Right operand array.
      quant: Optional quantization provider (qwix-style). When given, its
        ``dot_general_cls`` supplies a quantized ``dot_general`` implementation;
        when ``None`` the computation falls back to ``jnp.tensordot``.
      axes: Same contract as ``jnp.tensordot``'s ``axes``: an int ``N``
        contracts the last ``N`` axes of ``x`` with the first ``N`` axes of
        ``y``; a pair ``(x_axes, y_axes)`` (each an int or a sequence of ints)
        names the contracting axes explicitly.

    Returns:
      The contracted array.
    """
    if quant is None:
        # Unquantized path.
        return jnp.tensordot(x, y, axes=axes)

    # Translate tensordot-style `axes` into dot_general dimension_numbers.
    if isinstance(axes, int):
        x_contract = tuple(range(x.ndim - axes, x.ndim))
        y_contract = tuple(range(axes))
    else:
        x_axes, y_axes = axes
        # Also accept tensordot's scalar-pair form, e.g. axes=(1, 0), which the
        # original int/sequence-pair handling would mistranslate.
        x_contract = (x_axes,) if isinstance(x_axes, int) else tuple(x_axes)
        y_contract = (y_axes,) if isinstance(y_axes, int) else tuple(y_axes)
    dimension_numbers = ((x_contract, y_contract), ((), ()))

    # `dot_general_cls()` returns a dot_general class; the second call
    # instantiates it into the callable quantized op (qwix convention) —
    # presumably stateless per call; TODO(review) confirm against qwix docs.
    quantized_dot = quant.dot_general_cls()()
    return quantized_dot(lhs=x, rhs=y, dimension_numbers=dimension_numbers)
445461
446462
@@ -466,6 +482,7 @@ def mla_with_norms(
466482 rope_factor ,
467483 mscale ,
468484 dtype ,
485+ quant ,
469486):
470487 """Performs MLA with pre- and post-normalization."""
471488 (pre_attn_scale , post_attn_scale ), attn_ws = weights
@@ -500,6 +517,7 @@ def fn(args):
500517 dtype = dtype ,
501518 mscale = mscale ,
502519 attention_op_fn = attn_op ,
520+ quant = quant ,
503521 ),
504522 mesh ,
505523 )
@@ -535,6 +553,7 @@ def mla(
535553 mscale ,
536554 attention_op_fn ,
537555 dtype ,
556+ quant ,
538557):
539558 """Performs MLA."""
540559 (
@@ -563,6 +582,7 @@ def mla(
563582 dtype = dtype ,
564583 qk_nope_head_dim = qk_nope_head_dim ,
565584 mscale = mscale ,
585+ quant = quant ,
566586 )
567587 query = jax .ad_checkpoint .checkpoint_name (query , "query_proj" )
568588 key , value = kv_projection (
@@ -583,6 +603,7 @@ def mla(
583603 dtype = dtype ,
584604 qk_nope_head_dim = qk_nope_head_dim ,
585605 num_query_heads = num_query_heads ,
606+ quant = quant ,
586607 )
587608 key = jax .ad_checkpoint .checkpoint_name (key , "key_proj" )
588609 value = jax .ad_checkpoint .checkpoint_name (value , "value_proj" )
@@ -595,7 +616,7 @@ def mla(
595616 cached_values = [None , None ],
596617 )
597618 out = jax .ad_checkpoint .checkpoint_name (out , "attention_out" )
598- out = dot (out , out_weights , axes = 2 )
619+ out = dot (out , out_weights , quant = quant , axes = 2 )
599620 out = jax .ad_checkpoint .checkpoint_name (out , "out_proj" )
600621 return out
601622
@@ -618,6 +639,7 @@ def query_projection(
618639 rope_factor ,
619640 dtype ,
620641 mscale ,
642+ quant ,
621643):
622644 """Performs query projection."""
623645 # Set softmax scaling.
@@ -628,15 +650,15 @@ def query_projection(
628650 softmax_scale = softmax_scale * m * m
629651
630652 # LoRA path
631- low_rank_q = dot (inputs_q , wq_a_weights )
653+ low_rank_q = dot (inputs_q , wq_a_weights , quant = quant )
632654 low_rank_q = rms_norm (
633655 low_rank_q ,
634656 q_norm_scale_weights ,
635657 epsilon = epsilon ,
636658 dtype = dtype ,
637659 )
638660 low_rank_q = jax .ad_checkpoint .checkpoint_name (low_rank_q , "mla_q" )
639- q = dot (low_rank_q , wq_b_weights )
661+ q = dot (low_rank_q , wq_b_weights , quant = quant )
640662
641663 # Split into non-positional and rotary parts.
642664 q_nope , q_pe = jnp .split (q , [qk_nope_head_dim ], axis = - 1 )
@@ -675,9 +697,10 @@ def kv_projection(
675697 dtype ,
676698 qk_nope_head_dim ,
677699 num_query_heads ,
700+ quant ,
678701):
679702 """Performs KV projection."""
680- low_rank = dot (inputs , wkv_a_weights )
703+ low_rank = dot (inputs , wkv_a_weights , quant = quant )
681704 low_rank_main , low_rank_rope = jnp .split (low_rank , [kv_lora_rank ], axis = - 1 )
682705 low_rank_main = rms_norm (
683706 low_rank_main ,
@@ -706,12 +729,13 @@ def kv_projection(
706729 wkv_b_weights ,
707730 qk_nope_head_dim = qk_nope_head_dim ,
708731 num_query_heads = num_query_heads ,
732+ quant = quant ,
709733 )
710734
711735
712- def get_key_value (low_rank_main , key_rope , wkv_b_weights , * , qk_nope_head_dim , num_query_heads ):
736+ def get_key_value (low_rank_main , key_rope , wkv_b_weights , * , qk_nope_head_dim , num_query_heads , quant ):
713737 """Gets key and value from compressed KV latent vector and key rope."""
714- kv_out = dot (low_rank_main , wkv_b_weights )
738+ kv_out = dot (low_rank_main , wkv_b_weights , quant = quant )
715739
716740 # Split kv_out into key_nope and value parts.
717741 key_nope , value = jnp .split (kv_out , [qk_nope_head_dim ], axis = - 1 )
@@ -807,6 +831,7 @@ def moe(
807831 expert_axis_name ,
808832 use_gather_mosaic_kernel ,
809833 config ,
834+ quant ,
810835):
811836 """Performs dropless MoE with tensor/expert parallelism."""
812837 xs , ys = list (zip (* inputs ))
@@ -821,6 +846,7 @@ def moe(
821846 expert_axis_name = expert_axis_name ,
822847 use_gather_mosaic_kernel = use_gather_mosaic_kernel ,
823848 config = config ,
849+ quant = quant ,
824850 ),
825851 mesh ,
826852 )
@@ -851,9 +877,10 @@ def expert_selection(
851877 num_experts ,
852878 num_experts_per_tok ,
853879 routed_scaling_factor ,
880+ quant ,
854881):
855882 """Selects experts for each token and calculates group sizes for each expert."""
856- pre_bias_logits = jax .nn .sigmoid (dot (x , routing_kernel ))
883+ pre_bias_logits = jax .nn .sigmoid (dot (x , routing_kernel , quant = quant ))
857884 logits = pre_bias_logits + routing_bias
858885
859886 selected_experts , weights = expert_indices_and_weights (
@@ -1067,6 +1094,7 @@ def route_compute_unroute(
10671094 use_gather_mosaic_kernel ,
10681095 config ,
10691096 mesh ,
1097+ quant ,
10701098):
10711099 """Routes, processes, and unroutes activations."""
10721100 orig_shape = xs [0 ].shape
@@ -1078,7 +1106,9 @@ def route_compute_unroute(
10781106
10791107 def route_fn (inputs ):
10801108 # Shared expert.
1081- y = dot (jax .nn .silu (dot (inputs , shared_w0 )) * dot (inputs , shared_w1 ), shared_wo )
1109+ y = dot (
1110+ jax .nn .silu (dot (inputs , shared_w0 , quant = quant )) * dot (inputs , shared_w1 , quant = quant ), shared_wo , quant = quant
1111+ )
10821112
10831113 inputs = jnp .reshape (inputs , (- 1 , inputs .shape [- 1 ]))
10841114 selected_experts , weights , group_sizes = expert_selection (
@@ -1088,6 +1118,7 @@ def route_fn(inputs):
10881118 num_experts = num_experts ,
10891119 num_experts_per_tok = num_experts_per_tok ,
10901120 routed_scaling_factor = routed_scaling_factor ,
1121+ quant = quant ,
10911122 )
10921123 x , selected_experts , weights , group_sizes = route (
10931124 inputs ,
@@ -1140,6 +1171,7 @@ def process_activations(
11401171 expert_axis_name ,
11411172 use_gather_mosaic_kernel ,
11421173 config ,
1174+ quant ,
11431175):
11441176 """Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
11451177 activation_pspec = jax .sharding .PartitionSpec (
@@ -1164,6 +1196,7 @@ def process_activations(
11641196 use_gather_mosaic_kernel = use_gather_mosaic_kernel ,
11651197 config = config ,
11661198 mesh = mesh ,
1199+ quant = quant ,
11671200 ),
11681201 mesh = mesh ,
11691202 in_specs = (
0 commit comments