 
 """Alternative DeepSeek model definition with batch-split schedule."""
 
+import dataclasses
 import functools
 import math
-from typing import Sequence
+from typing import Any, Sequence
 
 import jax
 import jax.numpy as jnp
-from maxtext.kernels import megablox
-from maxtext.kernels import sort_activations
+import qwix.pallas as qpl
+import tokamax
+from flax import linen as nn
+
+from maxtext.kernels import megablox, sort_activations
 from MaxText.layers import attention_op
+from MaxText.layers import moe as moe_lib
 from MaxText.layers import quantizations
 
 
+@functools.partial(
+    jax.custom_vjp,
+    nondiff_argnums=(1, 2, 3),
+)
+def quantized_psum_scatter(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> jax.Array:
+  """Forward: Standard BF16 Reduce-Scatter.
+
+  Backward: Quantized FP8 All-Gather (DeepSeek optimization).
+
+  Args:
+    x: The input tensor.
+    axis_name: The axis name for the psum_scatter/all_gather operation.
+    scatter_dimension: The dimension along which to scatter.
+    tiled: Whether the scatter/gather is tiled.
+
+  Returns:
+    The result of the reduce-scatter operation.
+  """
+  return _q_psum_scatter_fwd(x, axis_name, scatter_dimension, tiled)[0]
+
+
+def _q_psum_scatter_fwd(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> tuple[jax.Array, None]:
+  out = jax.lax.psum_scatter(x, axis_name=axis_name, scatter_dimension=scatter_dimension, tiled=tiled)
+  return out, None
+
+
+def _q_psum_scatter_bwd(
+    axis_name: str,
+    scatter_dimension: int,
+    tiled: bool,
+    res: Any,
+    grads: jax.Array,
+) -> tuple[jax.Array]:  # pylint: disable=g-one-element-tuple
+  """Backward pass for quantized_psum_scatter.
+
+  Performs a quantized All-Gather of the gradients.
+
+  Args:
+    axis_name: The axis name for the all_gather operation.
+    scatter_dimension: The dimension along which the scatter occurred in the
+      forward pass.
+    tiled: Whether the gather is tiled.
+    res: The residuals from the forward pass (_q_psum_scatter_fwd). Unused:
+      the forward pass saves no residuals.
+    grads: The gradients from the next layer, which are in BF16.
+
+  Returns:
+    The dequantized and all-gathered gradients.
+  """
+  del res
+  # --- BACKWARD PASS (Dispatch) ---
+  # 'grads' is the BF16 gradient arriving from the next layer.
+  # We need to broadcast it back to all devices (All-Gather).
+
+  grads_q = qpl.quantize(
+      grads,
+      jnp.float8_e5m2,
+      channelwise_axes=[0],
+  )
+
+  gathered_qvals = jax.lax.all_gather(grads_q.qvalue, axis_name=axis_name, tiled=tiled, axis=scatter_dimension)
+
+  return (qpl.dequantize(dataclasses.replace(grads_q, qvalue=gathered_qvals)),)
+
+
+quantized_psum_scatter.defvjp(_q_psum_scatter_fwd, _q_psum_scatter_bwd)
+
+
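For orientation, a minimal usage sketch (not part of this change) of how quantized_psum_scatter might be invoked under jax.shard_map; the one-axis mesh, the "expert" axis name, and the shapes are illustrative assumptions.

# Hypothetical usage sketch; assumes a 1-D device mesh whose axis is named
# "expert". With tiled=True the scattered dimension of each per-device shard
# must be divisible by the axis size, hence the d * d leading dimension.
d = jax.device_count()
mesh = jax.make_mesh((d,), ("expert",))
spec = jax.sharding.PartitionSpec("expert")

def layer(x):
  # Forward: BF16 reduce-scatter over "expert". Under jax.grad, the backward
  # pass all-gathers an FP8-quantized (e5m2) gradient instead of a BF16 one.
  return quantized_psum_scatter(x, "expert", scatter_dimension=0, tiled=True)

x = jnp.ones((d * d, 16), jnp.bfloat16)  # per-device shard: (d, 16)
y = jax.shard_map(layer, mesh=mesh, in_specs=spec, out_specs=spec)(x)
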
 def fetch_weights(params, dtype):
   """Fetches weights from params in the proper format for batch-split schedule."""
   return jax.tree.map(
@@ -164,29 +241,7 @@ def batch_split_schedule(
       routed_scaling_factor=cfg.routed_scaling_factor,
       expert_axis_name="expert",
       use_gather_mosaic_kernel=False,
-      wi_tile_size=(
-          cfg.wi_tile_fwd_batch_seq,
-          cfg.wi_tile_fwd_embed_dim,
-          cfg.wi_tile_fwd_mlp_dim,
-          cfg.wi_tile_dlhs_batch_seq,
-          cfg.wi_tile_dlhs_embed_dim,
-          cfg.wi_tile_dlhs_mlp_dim,
-          cfg.wi_tile_drhs_batch_seq,
-          cfg.wi_tile_drhs_embed_dim,
-          cfg.wi_tile_drhs_mlp_dim,
-      ),
-      wo_tile_size=(
-          cfg.wo_tile_fwd_batch_seq,
-          cfg.wo_tile_fwd_embed_dim,
-          cfg.wo_tile_fwd_mlp_dim,
-          cfg.wo_tile_dlhs_batch_seq,
-          cfg.wo_tile_dlhs_embed_dim,
-          cfg.wo_tile_dlhs_mlp_dim,
-          cfg.wo_tile_drhs_batch_seq,
-          cfg.wo_tile_drhs_embed_dim,
-          cfg.wo_tile_drhs_mlp_dim,
-      ),
-      dtype=cfg.dtype,
+      config=cfg,
   )
   xs = jax.shard_map(
       functools.partial(merge, split_factor=cfg.batch_split_factor),
@@ -581,9 +636,7 @@ def moe(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
 ):
   """Performs dropless MoE with tensor/expert parallelism."""
   xs, ys = list(zip(*inputs))
@@ -597,9 +650,7 @@ def moe(
           routed_scaling_factor=routed_scaling_factor,
           expert_axis_name=expert_axis_name,
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
-          wi_tile_size=wi_tile_size,
-          wo_tile_size=wo_tile_size,
-          dtype=dtype,
+          config=config,
       ),
       mesh,
   )
@@ -691,20 +742,133 @@ def unroute(
   return jax.lax.psum_scatter(x, expert_axis_name, scatter_dimension=0, tiled=True)
 
 
-def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype):
+def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
   """Processes routed tokens through the MLP."""
-  gmm_fn = functools.partial(
-      megablox.gmm,
-      group_sizes=group_sizes,
-      preferred_element_type=dtype,
+
+  def gmm(
+      inputs,
+      kernel,
+      tiling,
+      group_sizes,
+      preferred_element_type,
+      weight_gather_axes,
+      input_buffer_count,
+      combine_scopes,
+  ):
+    # Dispatch: the megablox kernel when qwix quantization is enabled,
+    # otherwise the tokamax ragged-dot (grouped matmul) kernel.
+    if config.use_qwix_quantization:
+      output = megablox.gmm(
+          lhs=inputs,
+          rhs=kernel,
+          group_sizes=group_sizes,
+          preferred_element_type=preferred_element_type,
+          tiling=tiling,
+          use_qwix_quantization=config.use_qwix_quantization,
+          use_tokamax_backend=config.use_tokamax_gmm,
+          weight_gather_axes=weight_gather_axes,
+          input_buffer_count=input_buffer_count,
+          combine_scopes=combine_scopes,
+      )
+    else:
+      output = tokamax.ragged_dot(
+          lhs=inputs,
+          rhs=kernel,
+          group_sizes=group_sizes,
+          precision=jax.lax.Precision.DEFAULT,
+          preferred_element_type=preferred_element_type,
+          implementation="mosaic",
+      )
+    return output
+
+  gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)
+  wi_gather_axes = []
+  wo_gather_axes = []
+
+  # Nine-element tile tuples: a (batch_seq, embed, mlp) triple for each of
+  # the forward, dlhs, and drhs passes of the wi/wo GMMs.
+  wi_tile_size = (
+      config.wi_tile_fwd_batch_seq,
+      config.wi_tile_fwd_embed_dim,
+      config.wi_tile_fwd_mlp_dim,
+      config.wi_tile_dlhs_batch_seq,
+      config.wi_tile_dlhs_embed_dim,
+      config.wi_tile_dlhs_mlp_dim,
+      config.wi_tile_drhs_batch_seq,
+      config.wi_tile_drhs_embed_dim,
+      config.wi_tile_drhs_mlp_dim,
+  )
+  wo_tile_size = (
+      config.wo_tile_fwd_batch_seq,
+      config.wo_tile_fwd_embed_dim,
+      config.wo_tile_fwd_mlp_dim,
+      config.wo_tile_dlhs_batch_seq,
+      config.wo_tile_dlhs_embed_dim,
+      config.wo_tile_dlhs_mlp_dim,
+      config.wo_tile_drhs_batch_seq,
+      config.wo_tile_drhs_embed_dim,
+      config.wo_tile_drhs_mlp_dim,
+  )
+  wi_input_buffer_count = (
+      config.wi_tile_fwd_buffer_count,
+      config.wi_tile_dlhs_buffer_count,
+      config.wi_tile_drhs_buffer_count,
+  )
+  wo_input_buffer_count = (
+      config.wo_tile_fwd_buffer_count,
+      config.wo_tile_dlhs_buffer_count,
+      config.wo_tile_drhs_buffer_count,
+  )
+
+  wi_combine_scopes = config.wi_combine_scopes
+  wo_combine_scopes = config.wo_combine_scopes
+  if config.use_qwix_quantization:
+    gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
+    w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
+    wo_pspec = nn.logical_to_mesh_axes(linear_pspec)
+    ignored_axes = ("expert", "tensor", "tensor_transpose")
+
+    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
+      # Returns (axis_name, weight_dim) pairs for mesh axes that actually
+      # shard this weight dimension (size > 1 and not handled elsewhere).
+      if pspec_dim_axes is None:
+        return []
+      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
+      active = []
+      for ax in axes:
+        if ax and ax not in ignored_axes and mesh.shape.get(ax, 1) > 1:
+          active.append((ax, tensor_dim_index))
+      return active
+
+    wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
+    wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
+
+    wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
+    wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
+
+  layer_w0 = gmm_fn(
+      x,
+      w0,
+      tiling=wi_tile_size,
+      weight_gather_axes=wi_gather_axes,
+      input_buffer_count=wi_input_buffer_count,
+      combine_scopes=wi_combine_scopes,
+  )
+  layer_w1 = gmm_fn(
+      x,
+      w1,
+      tiling=wi_tile_size,
+      weight_gather_axes=wi_gather_axes,
+      input_buffer_count=wi_input_buffer_count,
+      combine_scopes=wi_combine_scopes,
   )
-  layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
-  layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
   layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
   layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
-  return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
+  layer_wo = gmm_fn(
+      intermediate_layer,
+      wo,
+      tiling=wo_tile_size,
+      weight_gather_axes=wo_gather_axes,
+      input_buffer_count=wo_input_buffer_count,
+      combine_scopes=wo_combine_scopes,
+  )
+  return layer_wo
 
 
 def route_compute_unroute(
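For reference, the grouped-matmul contract that gmm implements can be sketched with stock jax.lax.ragged_dot; the shapes and group sizes below are illustrative assumptions, and the real kernels add tiling, buffering, and quantization on top of these semantics.

# Illustrative sketch, not part of this change: each contiguous run of lhs
# rows (sized by group_sizes) is multiplied by its own expert matrix in rhs.
tokens, embed, mlp, num_experts = 8, 4, 6, 3
lhs = jnp.ones((tokens, embed), jnp.bfloat16)            # routed activations
rhs = jnp.ones((num_experts, embed, mlp), jnp.bfloat16)  # per-expert weights
group_sizes = jnp.array([3, 2, 3], jnp.int32)            # must sum to tokens
out = jax.lax.ragged_dot(lhs, rhs, group_sizes)          # rows 0..2 use expert 0, ...
assert out.shape == (tokens, mlp)
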
@@ -716,9 +880,8 @@ def route_compute_unroute(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
+    mesh,
 ):
   """Routes, processes, and unroutes activations."""
   orig_shape = xs[0].shape
@@ -760,9 +923,8 @@ def compute_fn(inputs):
         routed_wo,
         group_sizes,
         weights,
-        wi_tile_size=wi_tile_size,
-        wo_tile_size=wo_tile_size,
-        dtype=dtype,
+        config=config,
+        mesh=mesh,
     )
     return x, y, selected_experts
 
@@ -792,21 +954,21 @@ def process_activations(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
 ):
   """Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
   activation_pspec = jax.sharding.PartitionSpec(
       ("data", "fsdp", "fsdp_transpose", "expert", "context"),
       None,
       None,
   )
-  gating_pspec, linear_pspec = (
-      jax.sharding.PartitionSpec(None, None, expert_axis_name),
-      jax.sharding.PartitionSpec(None, expert_axis_name, None),
-  )
-
+  if config.use_qwix_quantization:
+    gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
+    gating_pspec = nn.logical_to_mesh_axes(gating_pspec)
+    linear_pspec = nn.logical_to_mesh_axes(linear_pspec)
+  else:
+    gating_pspec = jax.sharding.PartitionSpec(None, None, expert_axis_name)
+    linear_pspec = jax.sharding.PartitionSpec(None, expert_axis_name, None)
   return jax.shard_map(
       functools.partial(
           route_compute_unroute,
@@ -815,9 +977,8 @@ def process_activations(
           routed_scaling_factor=routed_scaling_factor,
           expert_axis_name=expert_axis_name,
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
-          wi_tile_size=wi_tile_size,
-          wo_tile_size=wo_tile_size,
-          dtype=dtype,
+          config=config,
+          mesh=mesh,
       ),
       mesh=mesh,
       in_specs=(
@@ -841,4 +1002,4 @@ def process_activations(
       ),
       out_specs=activation_pspec,
       check_vma=False,
-  )([x.astype(dtype) for x in xs], weights)
+  )([x.astype(config.dtype) for x in xs], weights)
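For context, a sketch of what nn.logical_to_mesh_axes does with the logical axis names returned by moe_lib.get_batchsplit_init_kernel_axes; the rule set and logical names below are assumptions for illustration, not MaxText's actual mapping.

# Hypothetical sketch: logical kernel axes are translated to mesh axes via
# the active flax logical-axis rules, yielding a jax PartitionSpec.
rules = (("exp", "expert"), ("embed", None), ("mlp", "tensor"))
with nn.logical_axis_rules(rules):
  pspec = nn.logical_to_mesh_axes(("exp", "embed", "mlp"))
# pspec == jax.sharding.PartitionSpec("expert", None, "tensor")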