Commit 1388c36

add another layer of custom vjp
1 parent 154c5d3 commit 1388c36

3 files changed: 227 additions & 241 deletions

src/maxtext/layers/moe.py

Lines changed: 1 addition & 0 deletions

@@ -1127,6 +1127,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         pre_bias_logits,
         self.config.use_custom_sort_vjp,
         roll_to_expert_id=num_experts_per_shard * expert_shard_id,
+        rngs=rngs,
     )

   # Filter down to the group sizes that apply to only the experts in the
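Note on the commit title: "another layer of custom vjp" refers to stacking custom-VJP rules. Below is a minimal standalone sketch (not the MaxText `wrapper` above, whose real signature appears in the hunk header) of how one `jax.custom_vjp` function can wrap another, so each layer owns its own backward rule:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def inner(x):
  return jnp.sin(x)

def inner_fwd(x):
  return inner(x), x  # residual: save x for the backward pass

def inner_bwd(x, g):
  return (g * jnp.cos(x),)  # hand-written rule for sin

inner.defvjp(inner_fwd, inner_bwd)

@jax.custom_vjp
def outer(x):
  return 2.0 * inner(x)  # another layer wrapping the inner custom-VJP

def outer_fwd(x):
  return outer(x), x

def outer_bwd(x, g):
  # Differentiate the composed forward; this goes through inner's custom rule.
  _, vjp = jax.vjp(lambda t: 2.0 * inner(t), x)
  return vjp(g)

outer.defvjp(outer_fwd, outer_bwd)

print(jax.grad(outer)(0.3))  # 2 * cos(0.3) ≈ 1.911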

src/maxtext/layers/pipeline.py

Lines changed: 117 additions & 161 deletions

@@ -15,8 +15,8 @@
 """Pipeline layer wrapping a decoder layer(s). Supports circular pipelining"""

 from typing import Any
+import functools

-import numpy as np
 from maxtext.utils import pipeline_utils

 from jax import numpy as jnp
@@ -469,11 +469,8 @@ def permute_output_micro_per_stage_dim(self, output):
     # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to
     # state_io - it will land on a different index of state_io depending on the number of iterations.
     microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage
-    permutation = (
-        np.arange(self.microbatches_per_stage) + microbatch_0_idx
-    ) % self.microbatches_per_stage  # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear
-    # in idx 1, etc
-    output = output[:, permutation]
+    output = jnp.roll(output, shift=-microbatch_0_idx, axis=1)
+    output = self._maybe_shard_with_logical(output, self.state_io_logical)
     return output

   def get_current_stage_weights(
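The change above swaps the index-gather permutation for `jnp.roll`; the two are equivalent, as this standalone check (not MaxText code) illustrates:

import numpy as np
from jax import numpy as jnp

n, land_idx = 6, 2
x = jnp.arange(4 * n).reshape(4, n)

permutation = (np.arange(n) + land_idx) % n     # old formulation
rolled = jnp.roll(x, shift=-land_idx, axis=1)   # new formulation
assert (x[:, permutation] == rolled).all()      # value at land_idx moves to idx 0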
@@ -554,35 +551,116 @@ def gather_weights_for_stages_in(w, spec):
     repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params)
     return repeat_weights

-  def from_all_variables_to_bsw(self, weights, loop_iteration, physical_partition_spec):
+  def from_all_variables_to_bsw(self, repeat_weights, physical_partition_spec):
     """All gather one branch of bsw using shardmap."""
-    repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
-    bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
+
+    bsw_pps = pipeline_utils.generate_bsw_pps_from_pps(physical_partition_spec)
     repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
-    fsdp_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec)
+    fsdp_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp")
+    fsdpt_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "fsdp_transpose")
+    expert_idx = pipeline_utils.get_fsdp_index_pytree(physical_partition_spec, "expert")

     @jax.shard_map(
         mesh=self.mesh,
-        in_specs=(repeat_weights_pps, None),
+        in_specs=(repeat_weights_pps, None, None, None),
         out_specs=bsw_pps,
         check_vma=True,
     )
-    def _all_gather_inner(sharded_weights, fsdp_idx):
-      def _all_gather_invariant(x, i):
+    def _all_gather_inner(sharded_weights, fsdp_idx, fsdpt_idx, expert_idx):
+      def _all_gather_with_path(path, x, i, j, k):
         if i >= 0:
-          return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
+          x = all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
+        if j >= 0:
+          x = all_gather_invariant(x, axis_name="fsdp_transpose", axis=j - 1, tiled=True)
+        # path_keys = [getattr(p, "key", str(p)) for p in path]
+        is_moe_block = True  # "MoeBlock_0" in path_keys TODO: Enable it
+        if k >= 0 and not is_moe_block:
+          x = all_gather_invariant(x, axis_name="expert", axis=k - 1, tiled=True)
         return x

-      return jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
+      return jax.tree_util.tree_map_with_path(_all_gather_with_path, sharded_weights, fsdp_idx, fsdpt_idx, expert_idx)

-    return _all_gather_inner(repeat_weights, fsdp_idx)
+    return _all_gather_inner(repeat_weights, fsdp_idx, fsdpt_idx, expert_idx)

   def bsw_all_gather_over_fsdp(self, weights, physical_partition_spec, loop_iteration):
     """All gather all bsw over fsdp mesh axis using shardmap."""
-    bsw_0 = self.from_all_variables_to_bsw(weights, loop_iteration, physical_partition_spec)
-    bsw_1 = self.from_all_variables_to_bsw(weights, loop_iteration + 1, physical_partition_spec)
+    cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration, physical_partition_spec)
+    nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1, physical_partition_spec)
+    bsw_0 = self.from_all_variables_to_bsw(cur_repeat_weights, physical_partition_spec)
+    bsw_1 = self.from_all_variables_to_bsw(nxt_repeat_weights, physical_partition_spec)
     return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")

+  def _run_initialization(
+      self,
+      example_inputs,
+      example_segmentation,
+      example_position,
+      segment_idx,
+      position_idx,
+      deterministic,
+      model_mode,
+  ):
+    """Runs the initialization sequence mapping layers appropriately based on pipeline settings."""
+    vmap_func = self.get_vmap_func_for_init()
+
+    if self.config.num_pipeline_repeats > 1:
+      # To shard the weights on initialization for the circular pipeline we create weights of
+      # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
+      # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
+      vmap_func = nn.vmap(
+          vmap_func,
+          in_axes=(0, segment_idx, position_idx, None, None),
+          variable_axes={
+              "params": 0,
+              "_overwrite_with_gradient": 0,
+              "non_trainable": 0,
+              "hyper_params": 0,
+          },
+          split_rngs={"params": True, "dropout": self.config.enable_dropout},
+          metadata_params={
+              nn.PARTITION_NAME: "circular_repeats",
+              "sub_weight_split_dims_mapping": (None,),
+              "is_initializing": True,
+              "x_times": self.config.num_pipeline_repeats,
+              "optimizer_dims_mapping": None,
+          },
+      )
+
+      example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
+      example_segmentation = (
+          jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
+          if example_segmentation is not None
+          else None
+      )
+      example_position = (
+          jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
+          if example_position is not None
+          else None
+      )
+
+    # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
+    # the full total_iterations.
+    example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
+    stage_outputs = vmap_func(
+        self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
+    )
+    if self.config.scan_layers:
+      stage_outputs = stage_outputs[0]
+
+    # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
+    # which has shape [pipeline_microbatch_size, sequence, embed]
+    if self.config.num_pipeline_repeats > 1:
+      stage_outputs = stage_outputs[0]  # Remove extra dimension created for the circular vmap
+    broadcasted_stage_outpus = jax.lax.broadcast(
+        stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
+    )
+
+    return jnp.reshape(
+        broadcasted_stage_outpus,
+        [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
+        out_sharding=self.output_sharding,
+    )
+
   def get_vmap_func_for_init(self):
     """This vmap func is used to initialize the weights only on init."""

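The move from `jax.tree.map` to `jax.tree_util.tree_map_with_path` in this hunk is what enables the (currently disabled) `MoeBlock_0` path filtering: the mapped function sees each leaf's location in the pytree and can branch on it. A minimal standalone sketch of that pattern, with invented names:

import jax

params = {"MoeBlock_0": {"w0": 1.0}, "mlp": {"wi": 2.0}}

def gather_unless_moe(path, x):
  # DictKey entries expose the dict key via .key
  path_keys = [getattr(p, "key", str(p)) for p in path]
  is_moe_block = "MoeBlock_0" in path_keys
  return x if is_moe_block else x * 10.0  # stand-in for all_gather_invariant

print(jax.tree_util.tree_map_with_path(gather_unless_moe, params))
# {'MoeBlock_0': {'w0': 1.0}, 'mlp': {'wi': 20.0}}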
@@ -741,13 +819,6 @@ def get_partition_spec_leaf(leaf):
     logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]}
     return logical_partition_spec

-  def _generate_bsw_pps_from_pps(self, physical_partition_spec):
-    """Create bsw physical partition spec from weight physical partition spec."""
-    return jax.tree.map(
-        lambda pps: P(*pipeline_utils.remove_fsdp_from_physical_partition_spec(pps)[1:]),
-        physical_partition_spec,
-    )
-
   @nn.compact
   def __call__(
       self,
@@ -815,63 +886,8 @@ def __call__(
     bubble_iterations = self.forwarding_delay * (self.num_stages - 1)

     if self.is_initializing():
-      vmap_func = self.get_vmap_func_for_init()
-
-      if self.config.num_pipeline_repeats > 1:
-        # To shard the weights on initialization for the circular pipeline we create weights of
-        # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
-        # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
-        vmap_func = nn.vmap(
-            vmap_func,
-            in_axes=(0, segment_idx, position_idx, None, None),
-            variable_axes={
-                "params": 0,
-                "_overwrite_with_gradient": 0,
-                "non_trainable": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"params": True, "dropout": self.config.enable_dropout},
-            metadata_params={
-                nn.PARTITION_NAME: "circular_repeats",
-                "sub_weight_split_dims_mapping": (None,),
-                "is_initializing": True,
-                "x_times": self.config.num_pipeline_repeats,
-                "optimizer_dims_mapping": None,
-            },
-        )
-
-        example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
-        example_segmentation = (
-            jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
-            if example_segmentation is not None
-            else None
-        )
-        example_position = (
-            jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
-            if example_position is not None
-            else None
-        )
-      # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
-      # the full total_iterations.
-      example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
-      stage_outputs = vmap_func(
-          self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
-      )
-      if self.config.scan_layers:
-        stage_outputs = stage_outputs[0]
-
-      # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output
-      # which has shape [pipeline_microbatch_size, sequence, embed]
-      if self.config.num_pipeline_repeats > 1:
-        stage_outputs = stage_outputs[0]  # Remove extra dimension created for the circular vmap
-      broadcasted_stage_outpus = jax.lax.broadcast(
-          stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
-      )
-
-      return jnp.reshape(
-          broadcasted_stage_outpus,
-          [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
-          out_sharding=self.output_sharding,
+      return self._run_initialization(
+          example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
       )

     logical_partition_spec = pipeline_utils.get_logical_spec_repeats_removed(logical_partition_spec)
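Both the removed inline block and its replacement `_run_initialization` lean on `jax.lax.broadcast`, which prepends the given sizes as new leading axes; a quick standalone check:

import jax
from jax import numpy as jnp

x = jnp.ones((4, 128))         # e.g. [pipeline_microbatch_size, embed]
y = jax.lax.broadcast(x, [2])  # e.g. num_pipeline_repeats = 2
print(y.shape)                 # (2, 4, 128)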
@@ -898,95 +914,35 @@ def run_iteration_scannable(model, loop_state):
         policy=self.get_pipeline_remat_policy(),
     )

-    def run_one_repeat_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_one_repeat_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=model.config.num_pipeline_microbatches,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_one_repeat_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
-
-    run_one_repeat_scannable = nn.remat(
-        run_one_repeat_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    base_scannable = functools.partial(
+        pipeline_utils.create_run_scannable,
+        model=self,
+        run_iteration_scannable=run_iteration_scannable,
+        deterministic=deterministic,
+        model_mode=model_mode,
+        logical_partition_spec=logical_partition_spec,
+        physical_partition_spec=physical_partition_spec,
+        positions=positions,
+        segment_ids=segment_ids,
     )

-    def run_bubbles_scannable(model, loop_state):
-      loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
-          loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
-      )
-
-      if model.config.scan_pipeline_iterations:
-        run_bubbles_scanned_custom = pipeline_utils.create_scanned_function(
-            model=model,
-            run_iteration_scannable=run_iteration_scannable,
-            length=bubble_iterations,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
-            deterministic=deterministic,
-            model_mode=model_mode,
-            logical_partition_spec=logical_partition_spec,
-        )
-        loop_state = run_bubbles_scanned_custom(loop_state, positions, segment_ids)
-      else:
-        for _ in range(model.config.num_pipeline_microbatches):
-          loop_state, _ = run_iteration_scannable(model, loop_state)
-      return loop_state, None
+    run_one_repeat_scannable = base_scannable(
+        length=self.config.num_pipeline_microbatches,
+    )

-    run_bubbles_scannable = nn.remat(
-        run_bubbles_scannable,
-        prevent_cse=not self.config.scan_pipeline_iterations,
-        policy=self.get_pipeline_remat_policy(),
+    run_bubbles_scannable = base_scannable(
+        length=bubble_iterations,
     )

     def run_all_iterations(model, loop_state):
       if self.config.scan_pipeline_repeats:
-        run_repeats_scanned = nn.scan(
-            run_one_repeat_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_repeats_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_one_repeat_scannable,
             length=model.config.num_pipeline_repeats,
         )

-        run_bubbles_scanned = nn.scan(
-            run_bubbles_scannable,
-            variable_axes={
-                "summaries": 0,
-                "aux_loss": 0,
-                "intermediates": 0,
-                "hyper_params": 0,
-            },
-            split_rngs={"random": True},
+        run_bubbles_scanned = pipeline_utils.create_run_repeats_scanned(
+            run_scannable=run_bubbles_scannable,
             length=1,
         )
         loop_state, _ = run_repeats_scanned(model, loop_state)
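The refactor in this hunk folds two near-duplicate scannable builders into a single `functools.partial`, with each call site supplying only `length`. A standalone sketch of the pattern; `make_scannable` here is a stand-in, not the real `pipeline_utils.create_run_scannable` signature:

import functools

def make_scannable(*, model, deterministic, model_mode, length):
  # Stand-in for a factory that returns a remat-wrapped scan body.
  return f"scannable(model={model}, det={deterministic}, mode={model_mode}, length={length})"

base_scannable = functools.partial(
    make_scannable,
    model="pipeline",
    deterministic=True,
    model_mode="train",
)

run_one_repeat_scannable = base_scannable(length=8)  # num_pipeline_microbatches
run_bubbles_scannable = base_scannable(length=3)     # bubble_iterations
print(run_one_repeat_scannable)
print(run_bubbles_scannable)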
