1515"""Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""
1616
1717from typing import Any
18+ import functools
1819
19- import numpy as np
2020from maxtext .utils import pipeline_utils
2121
2222from jax import numpy as jnp
@@ -469,11 +469,8 @@ def permute_output_micro_per_stage_dim(self, output):
     # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to
     # state_io - it will land on a different index of state_io depending on the number of iterations.
     microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage
-    permutation = (
-        np.arange(self.microbatches_per_stage) + microbatch_0_idx
-    ) % self.microbatches_per_stage  # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appears
-    # in idx 1, etc.
-    output = output[:, permutation]
+    output = jnp.roll(output, shift=-microbatch_0_idx, axis=1)
+    output = self._maybe_shard_with_logical(output, self.state_io_logical)
     return output

   def get_current_stage_weights(
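
As a sanity check on this hunk, jnp.roll with a negative shift reproduces exactly the gather that the deleted permutation indexing performed; a standalone NumPy illustration (not part of the change):

import numpy as np

m, idx = 8, 3  # microbatches_per_stage, microbatch_0_idx
x = np.arange(2 * m).reshape(2, m)
permutation = (np.arange(m) + idx) % m  # the deleted permutation
assert (x[:, permutation] == np.roll(x, shift=-idx, axis=1)).all()
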
@@ -554,35 +551,116 @@ def gather_weights_for_stages_in(w, spec):
     repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params)
     return repeat_weights

-  def from_all_variables_to_bsw(self, weights, loop_iteration, physical_partition_spec):
+  def from_all_variables_to_bsw(self, repeat_weights, physical_partition_spec):
     """All gather one branch of bsw using shardmap."""
-    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
-    bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
+
+    bsw_pps = pipeline_utils.generate_bsw_pps_from_pps(physical_partition_spec)
     repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
-    fsdp_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec)
+    fsdp_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp")
+    fsdpt_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp_transpose")
+    expert_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "expert")

     @jax.shard_map(
         mesh=self.mesh,
-        in_specs=(repeat_weights_pps, None),
+        in_specs=(repeat_weights_pps, None, None, None),
         out_specs=bsw_pps,
         check_vma=True,
     )
-    def _all_gather_inner(sharded_weights, fsdp_idx):
-      def _all_gather_invariant(x, i):
+    def _all_gather_inner(sharded_weights, fsdp_idx, fsdpt_idx, expert_idx):
+      def _all_gather_with_path(path, x, i, j, k):
         if i >= 0:
-          return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
+          x = all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
+        if j >= 0:
+          x = all_gather_invariant(x, axis_name="fsdp_transpose", axis=j - 1, tiled=True)
+        # path_keys = [getattr(p, "key", str(p)) for p in path]
+        is_moe_block = True  # TODO: enable the '"MoeBlock_0" in path_keys' check
+        if k >= 0 and not is_moe_block:
+          x = all_gather_invariant(x, axis_name="expert", axis=k - 1, tiled=True)
         return x

-      return jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
+      return jax.tree_util.tree_map_with_path(_all_gather_with_path, sharded_weights, fsdp_idx, fsdpt_idx, expert_idx)

-    return _all_gather_inner(repeat_weights, fsdp_idx)
+    return _all_gather_inner(repeat_weights, fsdp_idx, fsdpt_idx, expert_idx)

   def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
     """All gather all bsw over fsdp mesh axis using shardmap."""
-    bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec)
-    bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
+    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
+    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1, physical_partition_spec)
+    bsw_0 = self.from_all_variables_to_bsw(cur_repeat_weights, physical_partition_spec)
+    bsw_1 = self.from_all_variables_to_bsw(nxt_repeat_weights, physical_partition_spec)
     return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")

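checkpoint_name only tags the two gathered weight trees; whether they are saved or rematerialized on the backward pass is decided by the remat policy (here, presumably, whatever get_pipeline_remat_policy returns). A toy sketch of the mechanism with a name-based policy:

import jax
import jax.numpy as jnp

def f(x):
  w = jnp.sin(x)
  w = jax.ad_checkpoint.checkpoint_name(w, "bsw")  # tag the value for the policy below
  return jnp.sum(w * w)

# Save (do not rematerialize) every residual tagged "bsw".
f_remat = jax.checkpoint(f, policy=jax.checkpoint_policies.save_only_these_names("bsw"))
print(jax.grad(f_remat)(jnp.ones(4)))
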
+  def _run_initialization(
+      self,
+      example_inputs,
+      example_segmentation,
+      example_position,
+      segment_idx,
+      position_idx,
+      deterministic,
+      model_mode,
+  ):
+    """Runs the initialization sequence, mapping layers appropriately based on pipeline settings."""
+    vmap_func = self.get_vmap_func_for_init()
+
+    if self.config.num_pipeline_repeats > 1:
+      # To shard the weights on initialization for the circular pipeline we create weights of
+      # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
+      # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
+      vmap_func = nn.vmap(
+          vmap_func,
+          in_axes=(0, segment_idx, position_idx, None, None),
+          variable_axes={
+              "params": 0,
+              "_overwrite_with_gradient": 0,
+              "non_trainable": 0,
+              "hyper_params": 0,
+          },
+          split_rngs={"params": True, "dropout": self.config.enable_dropout},
+          metadata_params={
+              nn.PARTITION_NAME: "circular_repeats",
+              "sub_weight_split_dims_mapping": (None,),
+              "is_initializing": True,
+              "x_times": self.config.num_pipeline_repeats,
+              "optimizer_dims_mapping": None,
+          },
+      )
+
+      example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
+      example_segmentation = (
+          jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
+          if example_segmentation is not None
+          else None
+      )
+      example_position = (
+          jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
+          if example_position is not None
+          else None
+      )
+
+    # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
+    # the full total_iterations.
+    example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
+    stage_outputs = vmap_func(
+        self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
+    )
+    if self.config.scan_layers:
+      stage_outputs = stage_outputs[0]
+
+    # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stage's output,
+    # which has shape [pipeline_microbatch_size, sequence, embed].
+    if self.config.num_pipeline_repeats > 1:
+      stage_outputs = stage_outputs[0]  # Remove the extra dimension created for the circular vmap
+    broadcasted_stage_outputs = jax.lax.broadcast(
+        stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
+    )
+
+    return jnp.reshape(
+        broadcasted_stage_outputs,
+        [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
+        out_sharding=self.output_sharding,
+    )
+
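Both broadcasts in the new method rely on jax.lax.broadcast prepending the listed sizes as new leading dimensions (the repeat axis for init, then the microbatch replication for the output); for example, with illustrative shapes:

import jax
import jax.numpy as jnp

x = jnp.zeros((4, 128, 512))   # [pipeline_microbatch_size, sequence, embed]
y = jax.lax.broadcast(x, [3])  # prepend a num_pipeline_repeats axis
assert y.shape == (3, 4, 128, 512)
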
   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""

@@ -741,13 +819,6 @@ def get_partition_spec_leaf(leaf):
     logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]}
     return logical_partition_spec

-  def _generate_bsw_pps_from_pps(self, physical_partition_spec):
-    """Create bsw physical partition spec from weight physical partition spec."""
-    return jax.tree.map(
-        lambda pps: P(*pipeline_utils.remove_fsdp_from_physical_partition_spec(pps)[1:]),
-        physical_partition_spec,
-    )
-
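Per the new call site earlier in this diff, this helper presumably moved to pipeline_utils.generate_bsw_pps_from_pps unchanged. The P(*spec[1:]) idiom it shares with repeat_weights_pps drops the leading repeat/stage axis of a PartitionSpec:

from jax.sharding import PartitionSpec as P

pps = P("stage", "fsdp", "tensor")
print(P(*pps[1:]))  # PartitionSpec('fsdp', 'tensor') -- leading axis dropped
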
   @nn.compact
   def __call__(
       self,
@@ -815,63 +886,8 @@ def __call__(
     bubble_iterations = self.forwarding_delay * (self.num_stages - 1)

     if self.is_initializing():
-      vmap_func = self.get_vmap_func_for_init()
-
-      if self.config.num_pipeline_repeats > 1:
-        # To shard the weights on initialization for the circular pipeline we create weights of
-        # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
-        # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
-        vmap_func = nn.vmap(
-            vmap_func,
-            in_axes=(0, segment_idx, position_idx, None, None),
-            variable_axes={
-                "params": 0,
-                "_overwrite_with_gradient": 0,
-                "non_trainable": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"params": True, "dropout": self.config.enable_dropout},
-            metadata_params={
-                nn.PARTITION_NAME: "circular_repeats",
-                "sub_weight_split_dims_mapping": (None,),
-                "is_initializing": True,
-                "x_times": self.config.num_pipeline_repeats,
-                "optimizer_dims_mapping": None,
-            },
-        )
-
-        example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
-        example_segmentation = (
-            jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
-            if example_segmentation is not None
-            else None
-        )
-        example_position = (
-            jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
-            if example_position is not None
-            else None
-        )
-      # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
-      # the full total_iterations.
-      example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
-      stage_outputs = vmap_func(
-          self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
-      )
-      if self.config.scan_layers:
-        stage_outputs = stage_outputs[0]
-
-      # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stage's output,
-      # which has shape [pipeline_microbatch_size, sequence, embed].
-      if self.config.num_pipeline_repeats > 1:
-        stage_outputs = stage_outputs[0]  # Remove the extra dimension created for the circular vmap
-      broadcasted_stage_outputs = jax.lax.broadcast(
-          stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
-      )
-
-      return jnp.reshape(
-          broadcasted_stage_outputs,
-          [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
-          out_sharding=self.output_sharding,
+      return self._run_initialization(
+          example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
       )

     logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
@@ -898,95 +914,35 @@ def run_iteration_scannable(model, loop_state):
         policy=self.get_pipeline_remat_policy(),
     )

-    def run_one_repeat_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_one_repeat_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=model.config.num_pipeline_microbatches,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_one_repeat_scannable = nn.remat(
-        run_one_repeat_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    base_scannable = functools.partial(
+        pipeline_utils.create_run_scannable,
+        model=self,
+        run_iteration_scannable=run_iteration_scannable,
+        deterministic=deterministic,
+        model_mode=model_mode,
+        logical_partition_spec=logical_partition_spec,
+        physical_partition_spec=physical_partition_spec,
+        positions=positions,
+        segment_ids=segment_ids,
     )

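functools.partial pins the shared keyword arguments once, so the two scannables built below differ only in `length`; the pattern in miniature (names hypothetical):

import functools

def make_runner(length, model, deterministic):
  return lambda state: (state, length)  # stand-in body

base = functools.partial(make_runner, model="pipeline", deterministic=True)
run_repeat = base(length=8)  # only `length` differs per call site
run_bubble = base(length=3)
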
-    def run_bubbles_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_bubbles_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=bubble_iterations,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_bubbles_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
+    run_one_repeat_scannable = base_scannable(
+        length=self.config.num_pipeline_microbatches,
+    )

-    run_bubbles_scannable = nn.remat(
-        run_bubbles_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    run_bubbles_scannable = base_scannable(
+        length=bubble_iterations,
     )

     def run_all_iterations(model, loop_state):
       if self.config.scan_pipeline_repeats:
-        run_repeats_scanned = nn.scan(
-            run_one_repeat_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_one_repeat_scannable,
             length=model.config.num_pipeline_repeats,
         )

-        run_bubbles_scanned = nn.scan(
-            run_bubbles_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_bubbles_scannable,
             length=1,
         )
         loop_state, _ = run_repeats_scanned(model, loop_state)
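
Judging by the two deleted nn.scan calls, pipeline_utils.create_run_repeats_scanned presumably just centralizes their shared configuration; a sketch, assuming nothing else changed:

import flax.linen as nn

def create_run_repeats_scanned(run_scannable, length):
  # Reconstructed from the nn.scan configuration this diff deletes.
  return nn.scan(
      run_scannable,
      variable_axes={"summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0},
      split_rngs={"random": True},
      length=length,
  )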