@@ -854,45 +854,85 @@ def __call__(
854854 "slot" : slot ,
855855 }
856856 dense_layer = RemattedBlockLayers [0 ]
857- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
858- y , _ = self .scan_decoder_layers (
859- cfg ,
860- dense_layer ,
861- cfg .first_num_dense_layers ,
862- "dense_layers" ,
863- mesh ,
864- in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
865- model_mode = model_mode ,
866- )(y , * broadcast_args )
867857 moe_layer = RemattedBlockLayers [1 ]
868- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
869- num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
858+ if cfg .engram_layers :
859+ original_dense_call = dense_layer .__call__
860+ original_moe_call = moe_layer .__call__
861+ dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
862+ moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
863+
864+ common_kwargs = {
865+ "dense_layer" : dense_layer ,
866+ "moe_layer" : moe_layer ,
867+ "original_dense_call" : original_dense_call ,
868+ "original_moe_call" : original_moe_call ,
869+ "layer_call_kwargs" : layer_call_kwargs ,
870+ "decoder_segment_ids" : decoder_segment_ids ,
871+ "decoder_positions" : decoder_positions ,
872+ "deterministic" : deterministic ,
873+ "model_mode" : model_mode ,
874+ "decoder_input_tokens" : decoder_input_tokens ,
875+ "broadcast_args" : broadcast_args ,
876+ }
870877
871- # If batch-split schedule is used and initialization is complete,
872- # as detected by immutable params, use deepseek_batchsplit custom
873- # scan with initialized parameters.
874- if cfg .use_batch_split_schedule and not self .is_mutable_collection ("params" ):
875- y = deepseek_batchsplit .scan_batch_split_layers (
878+ # Apply Dense Layers
879+ y = self ._apply_interleaved_scanned_layers (
876880 y ,
877- self .variables ["params" ]["moe_layers" ],
878- decoder_positions ,
879- decoder_segment_ids ,
880- model_mode = model_mode ,
881- mesh = mesh ,
882- quant = self .quant ,
883- cfg = cfg ,
884- policy = policy ,
881+ layer_type = "dense" ,
882+ start_idx = 0 ,
883+ end_idx = cfg .first_num_dense_layers ,
884+ engram_indices = cfg .engram_layers ,
885+ ** common_kwargs ,
886+ )
887+
888+ # Apply MoE Layers
889+ y = self ._apply_interleaved_scanned_layers (
890+ y ,
891+ layer_type = "moe" ,
892+ start_idx = cfg .first_num_dense_layers ,
893+ end_idx = cfg .num_decoder_layers ,
894+ engram_indices = cfg .engram_layers ,
895+ ** common_kwargs ,
885896 )
886897 else :
898+ dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
887899 y , _ = self .scan_decoder_layers (
888900 cfg ,
889- moe_layer ,
890- num_moe_layers ,
891- "moe_layers " ,
901+ dense_layer ,
902+ cfg . first_num_dense_layers ,
903+ "dense_layers " ,
892904 mesh ,
893905 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
894906 model_mode = model_mode ,
895907 )(y , * broadcast_args )
908+ moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
909+ num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
910+
911+ # If batch-split schedule is used and initialization is complete,
912+ # as detected by immutable params, use deepseek_batchsplit custom
913+ # scan with initialized parameters.
914+ if cfg .use_batch_split_schedule and not self .is_mutable_collection ("params" ):
915+ y = deepseek_batchsplit .scan_batch_split_layers (
916+ y ,
917+ self .variables ["params" ]["moe_layers" ],
918+ decoder_positions ,
919+ decoder_segment_ids ,
920+ model_mode = model_mode ,
921+ mesh = mesh ,
922+ quant = self .quant ,
923+ cfg = cfg ,
924+ policy = policy ,
925+ )
926+ else :
927+ y , _ = self .scan_decoder_layers (
928+ cfg ,
929+ moe_layer ,
930+ num_moe_layers ,
931+ "moe_layers" ,
932+ mesh ,
933+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
934+ model_mode = model_mode ,
935+ )(y , * broadcast_args )
896936 elif cfg .decoder_block == DecoderBlockType .GEMMA3 :
897937 y = self ._apply_gemma3_scanned_blocks (
898938 y ,
@@ -1107,3 +1147,74 @@ def _apply_gemma3_scanned_blocks(
11071147 ** layer_call_kwargs ,
11081148 )
11091149 return y
1150+
1151+ # TODO(b/490118813): Relocate the following functions to their designated directories
1152+ # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer(),
1153+ # _apply_scanned_chunk(), and _apply_interleaved_scanned_layers().
1154+ def _find_next_boundary (self , current_idx , end_idx , engram_indices ):
1155+ """Finds the next index boundary, either the next Engram layer index or the overall end index."""
1156+ next_engrams = [l for l in engram_indices if l > current_idx ]
1157+ if next_engrams :
1158+ return min (end_idx , * next_engrams )
1159+ return end_idx
1160+
1161+ def _apply_single_engram_layer (self , y , current_idx , layer_type , ** kwargs ):
1162+ """Applies a single, unscanned Engram layer."""
1163+ layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
1164+ layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1165+ original_call = kwargs ["original_dense_call" ] if layer_type == "dense" else kwargs ["original_moe_call" ]
1166+ layer_call_kwargs = kwargs ["layer_call_kwargs" ]
1167+
1168+ layer .__call__ = original_call
1169+ y , _ = layer (
1170+ config = self .config ,
1171+ mesh = self .mesh ,
1172+ name = f"{ layer_prefix } _engram_{ current_idx } " ,
1173+ quant = self .quant ,
1174+ model_mode = self .model_mode ,
1175+ layer_idx = current_idx ,
1176+ )(
1177+ y ,
1178+ kwargs ["decoder_segment_ids" ],
1179+ kwargs ["decoder_positions" ],
1180+ kwargs ["deterministic" ],
1181+ kwargs ["model_mode" ],
1182+ decoder_input_tokens = kwargs ["decoder_input_tokens" ],
1183+ ** layer_call_kwargs ,
1184+ )
1185+ layer .__call__ = functools .partial (original_call , ** layer_call_kwargs )
1186+ return y
1187+
1188+ def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_type , ** kwargs ):
1189+ """Applies a contiguous chunk of layers using the scan operation."""
1190+ layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
1191+ layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1192+ broadcast_args = kwargs ["broadcast_args" ]
1193+ scan_length = next_boundary - current_idx
1194+
1195+ if scan_length > 0 :
1196+ y , _ = self .scan_decoder_layers (
1197+ self .config ,
1198+ layer ,
1199+ scan_length ,
1200+ f"{ layer_prefix } _{ current_idx } _{ next_boundary - 1 } " ,
1201+ self .mesh ,
1202+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
1203+ model_mode = kwargs ["model_mode" ],
1204+ )(y , * broadcast_args )
1205+ return y
1206+
1207+ def _apply_interleaved_scanned_layers (self , y , layer_type , start_idx , end_idx , engram_indices , ** kwargs ):
1208+ """Applies a mix of scanned standard layers and unscanned Engram layers."""
1209+ current_idx = start_idx
1210+ while current_idx < end_idx :
1211+ if current_idx in engram_indices :
1212+ # Handle individual unscanned Engram layer
1213+ y = self ._apply_single_engram_layer (y , current_idx , layer_type , ** kwargs )
1214+ current_idx += 1
1215+ else :
1216+ # Find next boundary and scan the chunk
1217+ next_boundary = self ._find_next_boundary (current_idx , end_idx , engram_indices )
1218+ y = self ._apply_scanned_chunk (y , current_idx , next_boundary , layer_type , ** kwargs )
1219+ current_idx = next_boundary
1220+ return y
0 commit comments