@@ -854,45 +854,85 @@ def __call__(
854854 "slot" : slot ,
855855 }
856856 dense_layer = RemattedBlockLayers [0 ]
857- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
858- y , _ = self .scan_decoder_layers (
859- cfg ,
860- dense_layer ,
861- cfg .first_num_dense_layers ,
862- "dense_layers" ,
863- mesh ,
864- in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
865- model_mode = model_mode ,
866- )(y , * broadcast_args )
867857 moe_layer = RemattedBlockLayers [1 ]
868- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
869- num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
858+ if cfg .engram_layers :
859+ original_dense_call = dense_layer .__call__
860+ original_moe_call = moe_layer .__call__
861+ dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
862+ moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
863+
864+ common_kwargs = {
865+ "dense_layer" : dense_layer ,
866+ "moe_layer" : moe_layer ,
867+ "original_dense_call" : original_dense_call ,
868+ "original_moe_call" : original_moe_call ,
869+ "layer_call_kwargs" : layer_call_kwargs ,
870+ "decoder_segment_ids" : decoder_segment_ids ,
871+ "decoder_positions" : decoder_positions ,
872+ "deterministic" : deterministic ,
873+ "model_mode" : model_mode ,
874+ "decoder_input_tokens" : decoder_input_tokens ,
875+ "broadcast_args" : broadcast_args ,
876+ }
870877
871- # If batch-split schedule is used and initialization is complete,
872- # as detected by immutable params, use deepseek_batchsplit custom
873- # scan with initialized parameters.
874- if cfg .use_batch_split_schedule and not self .is_mutable_collection ("params" ):
875- y = deepseek_batchsplit .scan_batch_split_layers (
878+ # Apply Dense Layers
879+ y = self ._apply_interleaved_scanned_layers (
876880 y ,
877- self .variables ["params" ]["moe_layers" ],
878- decoder_positions ,
879- decoder_segment_ids ,
880- model_mode = model_mode ,
881- mesh = mesh ,
882- quant = self .quant ,
883- cfg = cfg ,
884- policy = policy ,
881+ layer_type = "dense" ,
882+ start_idx = 0 ,
883+ end_idx = cfg .first_num_dense_layers ,
884+ engram_indices = cfg .engram_layers ,
885+ ** common_kwargs ,
886+ )
887+
888+ # Apply MoE Layers
889+ y = self ._apply_interleaved_scanned_layers (
890+ y ,
891+ layer_type = "moe" ,
892+ start_idx = cfg .first_num_dense_layers ,
893+ end_idx = cfg .num_decoder_layers ,
894+ engram_indices = cfg .engram_layers ,
895+ ** common_kwargs ,
885896 )
886897 else :
898+ dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
887899 y , _ = self .scan_decoder_layers (
888900 cfg ,
889- moe_layer ,
890- num_moe_layers ,
891- "moe_layers " ,
901+ dense_layer ,
902+ cfg . first_num_dense_layers ,
903+ "dense_layers " ,
892904 mesh ,
893905 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
894906 model_mode = model_mode ,
895907 )(y , * broadcast_args )
908+ moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
909+ num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
910+
911+ # If batch-split schedule is used and initialization is complete,
912+ # as detected by immutable params, use deepseek_batchsplit custom
913+ # scan with initialized parameters.
914+ if cfg .use_batch_split_schedule and not self .is_mutable_collection ("params" ):
915+ y = deepseek_batchsplit .scan_batch_split_layers (
916+ y ,
917+ self .variables ["params" ]["moe_layers" ],
918+ decoder_positions ,
919+ decoder_segment_ids ,
920+ model_mode = model_mode ,
921+ mesh = mesh ,
922+ quant = self .quant ,
923+ cfg = cfg ,
924+ policy = policy ,
925+ )
926+ else :
927+ y , _ = self .scan_decoder_layers (
928+ cfg ,
929+ moe_layer ,
930+ num_moe_layers ,
931+ "moe_layers" ,
932+ mesh ,
933+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
934+ model_mode = model_mode ,
935+ )(y , * broadcast_args )
896936 elif cfg .decoder_block == DecoderBlockType .GEMMA3 :
897937 y = self ._apply_gemma3_scanned_blocks (
898938 y ,
@@ -1118,3 +1158,74 @@ def _apply_gemma3_scanned_blocks(
11181158 ** layer_call_kwargs ,
11191159 )
11201160 return y
1161+
1162+ # TODO(b/490118813): Relocate the following functions to their designated directories
1163+ # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer(),
1164+ # _apply_scanned_chunk() and _apply_interleaved_scanned_layers().
1165+ def _find_next_boundary (self , current_idx , end_idx , engram_indices ):
1166+ """Finds the next index boundary, either the next Engram layer index or the overall end index."""
1167+ next_engrams = [l for l in engram_indices if l > current_idx ]
1168+ if next_engrams :
1169+ return min (end_idx , * next_engrams )
1170+ return end_idx
1171+
1172+ def _apply_single_engram_layer (self , y , current_idx , layer_type , ** kwargs ):
1173+ """Applies a single, unscanned Engram layer."""
1174+ layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
1175+ layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1176+ original_call = kwargs ["original_dense_call" ] if layer_type == "dense" else kwargs ["original_moe_call" ]
1177+ layer_call_kwargs = kwargs ["layer_call_kwargs" ]
1178+
1179+ layer .__call__ = original_call
1180+ y , _ = layer (
1181+ config = self .config ,
1182+ mesh = self .mesh ,
1183+ name = f"{ layer_prefix } _engram_{ current_idx } " ,
1184+ quant = self .quant ,
1185+ model_mode = self .model_mode ,
1186+ layer_idx = current_idx ,
1187+ )(
1188+ y ,
1189+ kwargs ["decoder_segment_ids" ],
1190+ kwargs ["decoder_positions" ],
1191+ kwargs ["deterministic" ],
1192+ kwargs ["model_mode" ],
1193+ decoder_input_tokens = kwargs ["decoder_input_tokens" ],
1194+ ** layer_call_kwargs ,
1195+ )
1196+ layer .__call__ = functools .partial (original_call , ** layer_call_kwargs )
1197+ return y
1198+
1199+ def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_type , ** kwargs ):
1200+ """Applies a contiguous chunk of layers using the scan operation."""
1201+ layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
1202+ layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1203+ broadcast_args = kwargs ["broadcast_args" ]
1204+ scan_length = next_boundary - current_idx
1205+
1206+ if scan_length > 0 :
1207+ y , _ = self .scan_decoder_layers (
1208+ self .config ,
1209+ layer ,
1210+ scan_length ,
1211+ f"{ layer_prefix } _{ current_idx } _{ next_boundary - 1 } " ,
1212+ self .mesh ,
1213+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
1214+ model_mode = kwargs ["model_mode" ],
1215+ )(y , * broadcast_args )
1216+ return y
1217+
1218+ def _apply_interleaved_scanned_layers (self , y , layer_type , start_idx , end_idx , engram_indices , ** kwargs ):
1219+ """Applies a mix of scanned standard layers and unscanned Engram layers."""
1220+ current_idx = start_idx
1221+ while current_idx < end_idx :
1222+ if current_idx in engram_indices :
1223+ # Handle individual unscanned Engram layer
1224+ y = self ._apply_single_engram_layer (y , current_idx , layer_type , ** kwargs )
1225+ current_idx += 1
1226+ else :
1227+ # Find next boundary and scan the chunk
1228+ next_boundary = self ._find_next_boundary (current_idx , end_idx , engram_indices )
1229+ y = self ._apply_scanned_chunk (y , current_idx , next_boundary , layer_type , ** kwargs )
1230+ current_idx = next_boundary
1231+ return y
0 commit comments