Skip to content

Commit 76b193b

Browse files
committed
add all gather weights over expert axes in attention component
1 parent 1388c36 commit 76b193b

2 files changed

Lines changed: 53 additions & 4 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,8 @@ def _all_gather_with_path(path, x, i, j, k):
572572
x = all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
573573
if j >= 0:
574574
x = all_gather_invariant(x, axis_name="fsdp_transpose", axis=j - 1, tiled=True)
575-
# path_keys = [getattr(p, "key", str(p)) for p in path]
576-
is_moe_block = True # "MoeBlock_0" in path_keys TODO: Enable it
575+
path_keys = [getattr(p, "key", str(p)) for p in path]
576+
is_moe_block = "MoeBlock_0" in path_keys
577577
if k >= 0 and not is_moe_block:
578578
x = all_gather_invariant(x, axis_name="expert", axis=k - 1, tiled=True)
579579
return x

src/maxtext/utils/pipeline_utils.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,61 @@ def find_fsdp(pspec):
5454

5555
def generate_bsw_pps_from_pps(physical_partition_spec):
  """Create bsw physical partition spec from weight physical partition spec.

  Every leaf of ``physical_partition_spec`` is mapped together with its tree
  path. The gathered mesh axes are stripped from the leaf (the 'expert' axis
  is kept when the path contains a "MoeBlock_0" key), and the first entry of
  the resulting spec is dropped before it is re-wrapped in a PartitionSpec.

  Args:
    physical_partition_spec: pytree whose leaves are physical PartitionSpecs.

  Returns:
    A pytree with the same structure holding the derived PartitionSpecs.
  """

  def _bsw_spec_for_leaf(key_path, leaf_spec):
    # Path entries exposing a `.key` attribute contribute that key; any other
    # entry kind falls back to its string form.
    keys = [getattr(entry, "key", str(entry)) for entry in key_path]
    in_moe_block_0 = "MoeBlock_0" in keys

    stripped = remove_gathered_axes_from_physical_partition_spec(leaf_spec, in_moe_block_0)

    # Drop the leading entry of the spec (presumably the 'stage' axis --
    # confirm against the callers).
    return P(*stripped[1:])

  return jax.tree_util.tree_map_with_path(_bsw_spec_for_leaf, physical_partition_spec)
6173

6274

75+
def remove_gathered_axes_from_physical_partition_spec(pps, is_moe_block_0):
  """Removes 'fsdp', 'fsdp_transpose', and conditionally 'expert' from a physical PartitionSpec.

  Each entry of the spec is rewritten so that the removed mesh axes no longer
  appear in it. An entry that loses all of its axes becomes ``None``
  (replicated), whether it was written as a single axis name or as a
  tuple/list of axis names.

  Args:
    pps: the physical PartitionSpec to rewrite. Any value that is not a
      PartitionSpec is returned unchanged.
    is_moe_block_0: when true, the 'expert' axis is left in place; otherwise
      it is removed along with 'fsdp' and 'fsdp_transpose'.

  Returns:
    A new PartitionSpec with the removed axes replaced by replication, or
    ``pps`` itself when it is not a PartitionSpec.

  Raises:
    ValueError: if an entry is neither ``None``, a string, a list nor a tuple.
  """
  if not isinstance(pps, P):
    return pps

  # 'fsdp' and 'fsdp_transpose' are always removed; 'expert' only outside
  # MoeBlock_0.
  axes_to_remove = {"fsdp", "fsdp_transpose"}
  if not is_moe_block_0:
    axes_to_remove.add("expert")

  new_spec = []
  for axis in pps:
    if axis is None:
      new_spec.append(None)
    elif isinstance(axis, str):
      # A removed axis is replaced by None to signify replication.
      new_spec.append(None if axis in axes_to_remove else axis)
    elif isinstance(axis, (list, tuple)):
      remaining = tuple(a for a in axis if a not in axes_to_remove)
      # Normalize a fully-filtered entry to None so a replicated dimension is
      # spelled the same way as in the single-string branch above.
      new_spec.append(remaining or None)
    else:
      raise ValueError(f"Unsupported_axis_type: {type(axis)}")

  return P(*new_spec)
110+
111+
63112
def get_logical_spec_repeats_removed(full_logical):
64113
"""Removes 'circular_repeats' from logical partition spec."""
65114
if full_logical is None:

0 commit comments

Comments
 (0)