
Commit d272058

gobbleturk authored and Google-ML-Automation committed
Add internal functionality for train_compile
PiperOrigin-RevId: 878559828
1 parent 441bc95 commit d272058

3 files changed: 46 additions & 22 deletions

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
@@ -413,6 +413,10 @@ jax_cache_dir: "~/jax_cache"
 # Hardware
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'
 
+# internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc.
+internal_compile: False
+internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile.
+
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
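
A hypothetical invocation using the new keys, assuming MaxText's usual key=value config overrides (the module path, device count, and topology placeholder below are illustrative, not from this commit):

python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml \
  internal_compile=true internal_compile_num_devices=256 \
  compile_topology=<internal-topology-name> compile_topology_num_slices=1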

src/maxtext/configs/types.py

Lines changed: 7 additions & 0 deletions
@@ -795,6 +795,8 @@ class LayoutAndSharding(BaseModel):
       description="Allowed percentage of non-sharded parameters.",
   )
   shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
+  internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
+  internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
 
 
 class DcnParallelism(BaseModel):

@@ -2064,6 +2066,11 @@ def validate_and_set_hlo_dump_defaults():
   # E. HARDWARE-DEPENDENT CALCULATIONS
   def get_num_target_devices():
     """Get the number of devices for the target topology, handling AOT compilation and single-controller modes."""
+    if self.internal_compile:
+      if self.internal_compile_num_devices <= 0:
+        raise ValueError("Set internal_compile_num_devices to a positive integer.")
+      # User bypassing topology mappings should supply explicit device count
+      return self.internal_compile_num_devices
     if self.compile_topology:
       spec = accelerator_to_spec_map.get_system_characteristics(self.compile_topology)
       return int(spec.devices_per_slice * self.compile_topology_num_slices)
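
A minimal sketch of how the new guard behaves, using a stripped-down pydantic model (field names and defaults are from the diff; the enclosing config class and the topology-spec fallback are simplified away):

from pydantic import BaseModel, Field


class CompileOptions(BaseModel):
  """Stand-in for the real config model; only the new fields are shown."""

  internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
  internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")


def get_num_target_devices(cfg: CompileOptions) -> int:
  # Mirrors the added branch: internal_compile short-circuits the topology lookup.
  if cfg.internal_compile:
    if cfg.internal_compile_num_devices <= 0:
      raise ValueError("Set internal_compile_num_devices to a positive integer.")
    return cfg.internal_compile_num_devices
  raise NotImplementedError("topology-spec path elided; see the diff above")


print(get_num_target_devices(CompileOptions(internal_compile=True, internal_compile_num_devices=256)))  # -> 256

With internal_compile=True and the default internal_compile_num_devices=-1, the call raises, which is the failure mode the -1 sentinel is meant to surface early.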

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 35 additions & 22 deletions
@@ -60,22 +60,27 @@ def validate_config(config):
 
 def get_topology_mesh(config):
   """Get the target hardware devices, and create configured mesh with them"""
-  target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
-  if target_hardware.platform == "gpu":
-    # Disable sharded autotuning. This is an optimization to distribute
-    # autotuning across the fleet, but can cause hangs with AoT compilation.
-    os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
-    jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
-    topology_devices = jax.devices()
-  else:
+  if config.internal_compile:
     topology_devices = get_topology_desc(
-        platform=target_hardware.platform,
-        topology_name=target_hardware.topology_name,
-        chip_config_name=target_hardware.chip_config_name,
-        chips_per_host_bounds=target_hardware.chips_per_host_bounds,
-        num_slices=config.compile_topology_num_slices,
-        wrap=target_hardware.wrap,
+        platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
     ).devices
+  else:
+    target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
+    if target_hardware.platform == "gpu":
+      # Disable sharded autotuning. This is an optimization to distribute
+      # autotuning across the fleet, but can cause hangs with AoT compilation.
+      os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
+      jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
+      topology_devices = jax.devices()
+    else:
+      topology_devices = get_topology_desc(
+          platform=target_hardware.platform,
+          topology_name=target_hardware.topology_name,
+          chip_config_name=target_hardware.chip_config_name,
+          chips_per_host_bounds=target_hardware.chips_per_host_bounds,
+          num_slices=config.compile_topology_num_slices,
+          wrap=target_hardware.wrap,
+      ).devices
   if config.shard_mode == ShardMode.EXPLICIT:
     jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
   topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
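
For context, get_topology_desc (from jax.experimental.topologies) describes a device topology without any attached hardware, which is what makes ahead-of-time compilation possible. A minimal sketch of that workflow under assumed names (the "v4:2x2x1" topology string and single "data" mesh axis are illustrative, not from this commit):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.topologies import get_topology_desc
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Describe an abstract TPU topology; no TPUs need to be attached.
topo = get_topology_desc(platform="tpu", topology_name="v4:2x2x1")
mesh = Mesh(np.array(topo.devices), ("data",))
x_spec = jax.ShapeDtypeStruct((8, 128), jnp.float32, sharding=NamedSharding(mesh, P("data")))

# Lower and compile ahead of time against the abstract devices.
compiled = jax.jit(lambda x: x * 2).lower(x_spec).compile()
print(compiled.memory_analysis())

The internal_compile branch above feeds config.compile_topology straight to this call instead of first translating it through accelerator_to_spec_map, which is exactly the bypass the new flag enables.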
@@ -174,10 +179,14 @@ def is_oom(argv: Sequence[str]) -> bool:
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
 
   # Get function to compile and shardings
-  func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-      maxtext_utils.get_functional_train_with_signature(
-          train.train_step, data_sharding, state_mesh_shardings, model, config
-      )
+  (
+      func_to_compile,
+      in_shard,
+      out_shard,
+      static_argnums,
+      donate_argnums,
+  ) = maxtext_utils.get_functional_train_with_signature(
+      train.train_step, data_sharding, state_mesh_shardings, model, config
   )
 
   try:
@@ -255,10 +264,14 @@ def main(argv: Sequence[str]) -> None:
     donate_argnums = 0
   else:
     # Get function to compile and shardings
-    func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-        maxtext_utils.get_functional_train_with_signature(
-            train.train_step, data_sharding, state_mesh_shardings, model, config
-        )
+    (
+        func_to_compile,
+        in_shard,
+        out_shard,
+        static_argnums,
+        donate_argnums,
+    ) = maxtext_utils.get_functional_train_with_signature(
+        train.train_step, data_sharding, state_mesh_shardings, model, config
     )
 
     # print weights sharding info under debug sharding mode
