Skip to content

Commit e1a2ba7

Browse files
Merge pull request #3438 from AI-Hypercomputer:chengnuojin-custom-logical
PiperOrigin-RevId: 886413492
2 parents 0efc6ca + 307bc11 commit e1a2ba7

8 files changed

Lines changed: 200 additions & 16 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ internal_compile_num_devices: -1 # You must specify the number of devices when u
431431

432432
# Parallelism
433433
shard_mode: "auto" # can be either auto or explicit
434+
custom_mesh_and_rule: "" # replace the default mesh and logical axis rules by specifying a yml name under configs/custom_mesh_and_rule/.
434435
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
435436
logical_axis_rules: [
436437
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
# EP x FSDP > 512.
#
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
# second, it may be required for DCN communication.
#
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
# store prefetched weights.
mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
logical_axis_rules: [
  ['activation_batch', ['data', 'fsdp', 'expert']],
  ['activation_batch_no_exp', ['data', 'fsdp']],
  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
  ['activation_heads', ['tensor']],
  ['activation_kv_heads', ['tensor']],
  ['activation_length', ['expert']],
  ['activation_attn_length', ['expert']],
  ['activation_q_length', ['expert']],
  ['activation_attn_embed', ['tensor']],
  ['activation_embed', ['tensor']],
  ['activation_mlp', ['tensor']],
  ['activation_kv', ['tensor']],
  ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
  ['activation_kv_batch', ['data', 'fsdp', 'expert']],
  ['activation_kv_batch_no_exp', ['data', 'fsdp']],
  ['activation_kv_head_dim', ['tensor']],
  ['activation_vocab', ['tensor']],
  ['activation_stage', 'stage'],
  ['activation_exp', ['expert']],
  ['decode_batch', ['data', 'fsdp', 'expert']],
  ['mlp', ['tensor']],
  ['mlp_no_fsdp', ['tensor']],
  ['vocab', ['tensor']],
  ['heads', ['tensor']],
  ['q_heads', ['tensor']],
  ['kv_heads', ['tensor']],
  ['embed', ['fsdp', 'expert']],
  ['embed_no_exp', ['fsdp']],
  ['q_lora', ['fsdp']],
  ['kv_lora', ['fsdp']],
  ['norm', ['tensor']],
  ['layers', 'stage'],
  ['cache_heads', ['tensor']],
  ['exp', 'expert'],
  ['exp_with_fsdp', 'fsdp'],
  ['paged_kv_heads', ['tensor']],
  ['engram_dim', ['tensor']],
]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
# for small-scale training and this rule simplifies the overall configuration.
mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
logical_axis_rules: [
  ['activation_batch', ['fsdp']],
  ['activation_batch_no_exp', ['fsdp']],
  ['activation_embed_and_logits_batch', ['fsdp']],
  ['activation_embed_and_logits_batch_sequence', ['fsdp']],
  ['activation_prefill_kv_batch', ['fsdp']],
  ['activation_kv_batch', ['fsdp']],
  ['activation_kv_batch_no_exp', ['fsdp']],
  ['decode_batch', ['fsdp']],
  ['embed', ['fsdp']],
  ['embed_no_exp', ['fsdp']],
  ['q_lora', ['fsdp']],
  ['kv_lora', ['fsdp']],
  ['exp_with_fsdp', 'fsdp'],
]

src/maxtext/configs/types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from math import prod
2626
import os
2727
from tempfile import gettempdir
28+
import yaml
2829
from typing import Any, Literal, NewType, Optional
2930

3031
import jax
@@ -781,6 +782,7 @@ class HardwareAndMesh(BaseModel):
781782
description="Strategy for context parallelism ('all_gather' or 'ring').",
782783
)
783784
custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
785+
custom_mesh_and_rule: str = Field("", description="Customized mesh and logical rules for granularity.")
784786
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
785787
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
786788
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
@@ -1962,6 +1964,24 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
19621964
Computes all derived values and runs all cross-field validations after initial parsing.
19631965
This logic is ported from the legacy pyconfig_deprecated.py system and adapted for Pydantic.
19641966
"""
1967+
if self.custom_mesh_and_rule:
1968+
custom_mesh_path = os.path.join(
1969+
os.path.dirname(os.path.abspath(__file__)),
1970+
"custom_mesh_and_rule",
1971+
f"{self.custom_mesh_and_rule}.yml",
1972+
)
1973+
if os.path.exists(custom_mesh_path):
1974+
with open(custom_mesh_path, "r") as f: # pylint: disable=unspecified-encoding
1975+
custom_mesh_config = yaml.safe_load(f)
1976+
if "mesh_axes" in custom_mesh_config:
1977+
self.mesh_axes = custom_mesh_config["mesh_axes"]
1978+
if "logical_axis_rules" in custom_mesh_config:
1979+
self.logical_axis_rules = custom_mesh_config["logical_axis_rules"]
1980+
if "data_sharding" in custom_mesh_config:
1981+
self.data_sharding = custom_mesh_config["data_sharding"]
1982+
else:
1983+
raise NotImplementedError(f"Custom mesh config file not found at {custom_mesh_path}")
1984+
19651985
# A. SET RUN NAME AND PATHS
19661986
# If run_name is not set, generate one from the JOBSET_NAME environment variable (if available)
19671987
# or create one from the model name and a timestamp.

src/maxtext/layers/attention_op.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ def tpu_flash_attention(
11841184
global_k_layout = self.config.sa_k_layout
11851185
global_v_layout = self.config.sa_v_layout
11861186

1187-
devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
1187+
devices_in_data_fsdp = self.mesh.shape.get("data", 1) * self.mesh.shape.get("fsdp", 1)
11881188
assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
11891189
"Batch dimension should be shardable among the devices in data and fsdp"
11901190
" axis"
@@ -1284,22 +1284,17 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
12841284
jax.jit,
12851285
static_argnames=[
12861286
"single_head_mask",
1287-
"shard_head_size",
12881287
],
12891288
)
1290-
def wrap_splash_kernel(single_head_mask, shard_head_size=1):
1289+
def wrap_splash_kernel(single_head_mask):
12911290
splash_kernel = tokamax_splash_kernel.make_splash_mha(
12921291
mask=single_head_mask,
12931292
config=sa_config,
12941293
q_seq_shards=cp_size, # axis for sequence sharding,
12951294
)
12961295
return splash_kernel
12971296

1298-
logical_axis_rules_head = np.array(
1299-
[self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
1300-
)
1301-
shard_head_size = np.prod(logical_axis_rules_head)
1302-
splash_kernel = wrap_splash_kernel(single_head_mask, int(shard_head_size))
1297+
splash_kernel = wrap_splash_kernel(single_head_mask)
13031298
if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
13041299
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
13051300
else:
@@ -1331,11 +1326,10 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
13311326
)
13321327
return splash_kernel
13331328

1334-
logical_axis_rules_head = np.array(
1335-
[self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]]
1336-
)
1337-
shard_head_size = np.prod(logical_axis_rules_head)
1338-
splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
1329+
head_physical_axes = logical_to_mesh_axes((HEAD,), self.mesh)[0]
1330+
head_physical_axes = (head_physical_axes,) if isinstance(head_physical_axes, str) else (head_physical_axes or ())
1331+
shard_head_size = math.prod(self.mesh.shape.get(ax, 1) for ax in head_physical_axes)
1332+
splash_kernel = wrap_splash_kernel(multi_head_mask, shard_head_size)
13391333
named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
13401334
segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
13411335

src/maxtext/utils/sharding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def remove_size_one_mesh_axis(spec, mesh):
151151
if s is None or s == P.UNCONSTRAINED:
152152
new_spec.append(s) # type: ignore
153153
elif isinstance(s, tuple):
154-
new_spec.append(tuple(i for i in s if mesh.shape[i] != 1))
154+
new_spec.append(tuple(i for i in s if mesh.shape.get(i, 1) != 1))
155155
else:
156-
new_spec.append(None if mesh.shape[s] == 1 else s) # type: ignore
156+
new_spec.append(None if mesh.shape.get(s, 1) == 1 else s) # type: ignore
157157
return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced)
158158

159159

src/maxtext/utils/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def setup_train_loop(config, recorder, devices=None):
204204
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
205205
rampup_manager = create_rampup_manager(config, checkpoint_manager)
206206
data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager)
207-
context_parallel_size = mesh.shape["context"]
207+
context_parallel_size = mesh.shape.get("context", 1)
208208
# Check if context parallelism is being used with sequence packing
209209
if context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic":
210210
raise ValueError(
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for custom mesh and sharding rule configurations in model compilation.
16+
17+
This module verifies that `train_compile.py` correctly processes the
18+
`custom_mesh_and_rule` flag. It ensures that user-defined hardware
19+
meshes and parallelization strategies compile successfully prior to execution.
20+
"""
21+
22+
import unittest
23+
24+
import pytest
25+
26+
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
27+
from tests.utils.test_helpers import get_test_config_path
28+
29+
30+
@pytest.mark.tpu_backend
31+
class CustomMeshAndRuleTest(unittest.TestCase):
32+
"""Tests for custom_mesh functionality in train_compile.py"""
33+
34+
@pytest.mark.cpu_only
35+
def test_pure_fsdp(self):
36+
"""Test compiling with a pure FSDP custom mesh."""
37+
train_compile_main(
38+
(
39+
"",
40+
get_test_config_path(),
41+
"compile_topology=v4-8",
42+
"compile_topology_num_slices=1",
43+
"base_emb_dim=256",
44+
"base_mlp_dim=256",
45+
"base_num_decoder_layers=1",
46+
"custom_mesh_and_rule=pure-fsdp",
47+
)
48+
)
49+
50+
@pytest.mark.cpu_only
51+
def test_ds3_large_pp(self):
52+
"""Test compiling deepseek3-tiny with the pipeline-large-moe custom mesh."""
53+
train_compile_main(
54+
(
55+
"",
56+
get_test_config_path(),
57+
"compile_topology=v5p-32",
58+
"compile_topology_num_slices=1",
59+
"ici_fsdp_transpose_parallelism=2",
60+
"ici_expert_parallelism=2",
61+
"model_name=deepseek3-tiny",
62+
"override_model_config=true",
63+
"base_emb_dim=256",
64+
"base_mlp_dim=256",
65+
"base_num_decoder_layers=4",
66+
"custom_mesh_and_rule=pipeline-large-moe",
67+
)
68+
)

0 commit comments

Comments
 (0)