
Commit 27e5660

Merge pull request #3204 from AI-Hypercomputer:bvandermoon-repo-restructure
PiperOrigin-RevId: 873129838
2 parents: ec712bf + 9dc2704

12 files changed

Lines changed: 339 additions & 303 deletions


docs/guides/monitoring_and_debugging/features_and_diagnostics.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -56,7 +56,7 @@ After installing the dependencies listed above, you are ready to compile ahead o

 ```sh
 # Run the below on a single machine, e.g. a CPU
-python3 MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \
+python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \
 global_parameter_scale=16 per_device_batch_size=4
 ```

@@ -71,7 +71,7 @@ Here is an example that saves then loads the compiled `train_step`, starting wit
 ```sh
 # Run the below on a single machine, e.g. a CPU
 export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
-python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 \
+python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 \
 compile_topology_num_slices=2 \
 compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \
 per_device_batch_size=4 steps=10000 learning_rate=1e-3
@@ -109,7 +109,7 @@ This example illustrates the flags to use for a multihost GPU compilation target
 ```sh
 # Run the below on a single A3 machine
 export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
-python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=a3 \
+python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=a3 \
 compile_topology_num_slices=4 \
 compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \
 attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
````
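The commands above only change the module path; the ahead-of-time workflow they document is unchanged. For readers unfamiliar with that workflow, here is a minimal, self-contained JAX sketch (not MaxText code) of what `train_compile` does at its core: lower a step function against shape-only stand-ins, compile it without allocating any arrays, and inspect the compiler's cost and memory estimates. The toy `train_step`, shapes, and dtypes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def train_step(params, batch):
  # Toy stand-in for a real training step: one matmul and a scalar "loss".
  return jnp.mean((batch @ params) ** 2)


# Shape/dtype-only stand-ins: no device memory is allocated for these.
abstract_params = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)
abstract_batch = jax.ShapeDtypeStruct((8, 1024), jnp.float32)

# Lower and compile ahead of time, then inspect the compiler estimates.
compiled = jax.jit(train_step).lower(abstract_params, abstract_batch).compile()
print(compiled.cost_analysis())
print(compiled.memory_analysis())
```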

src/MaxText/estimator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@
 import jax

 from MaxText import pyconfig
-from MaxText import train_compile
+from maxtext.trainers.pre_train import train_compile


 def generate_priority_list(config, provided_tensor_names):
```
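The commit switches `estimator.py` to the new import path outright. As a hedged aside (not what the commit does), a fallback import is one way such a move could be staged so the file keeps working on checkouts that still have the old flat layout:

```python
# Illustrative compatibility pattern only; the commit imports the new path directly.
try:
  from maxtext.trainers.pre_train import train_compile
except ImportError:  # older checkouts with the flat MaxText layout
  from MaxText import train_compile
```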

src/MaxText/train_compile.py

Lines changed: 16 additions & 291 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,300 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""
-Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step
-Generates shaped versions of state and data without ever constructing them, so its possible
-to compile with target hardware (e.g. hundreds/thousands of chips), without using the hardware.
-This helpfully detects if your configuration would run into memory problems (OOM) on the target hardware,
-before having to use the target hardware - you will see the same OOM error message during this compilation
-as you would on the target hardware.
-"""
+"""Shim for pre-training trainers in `src/maxtext/trainers/pre_train`."""

-import functools
-import os
-import pickle
-from typing import Sequence
+import sys
+import importlib

-from absl import app
-from flax.linen import partitioning as nn_partitioning
-import jax
-from jax.experimental.serialize_executable import serialize
-from jax.experimental.topologies import get_topology_desc
-from jax.sharding import AxisType, Mesh
-from MaxText import accelerator_to_spec_map
-from MaxText import optimizers
-from MaxText import pyconfig
-from MaxText import sharding
-from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
-from maxtext.layers import quantizations
-from maxtext.models import models
-from maxtext.trainers.diloco import diloco
-from maxtext.trainers.pre_train import train
-from maxtext.utils import gcs_utils
-from maxtext.utils import max_utils
-from maxtext.utils import maxtext_utils
+from absl import logging

-# pylint: disable=too-many-positional-arguments
+from maxtext.utils import max_logging

-Transformer = models.transformer_as_linen
-
-
-def validate_config(config):
-  """Validates the config is is setup correctly to compile, returning a useful error message if not."""
-  assert (
-      config.compile_topology != ""
-  ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256"
-  assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer"
-
-
-def get_topology_mesh(config):
-  """Get the target hardware devices, and create configured mesh with them"""
-  target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
-  if target_hardware.platform == "gpu":
-    # Disable sharded autotuning. This is an optimization to distribute
-    # autotuning across the fleet, but can cause hangs with AoT compilation.
-    os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
-    jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
-    topology_devices = jax.devices()
-  else:
-    topology_devices = get_topology_desc(
-        platform=target_hardware.platform,
-        topology_name=target_hardware.topology_name,
-        chip_config_name=target_hardware.chip_config_name,
-        chips_per_host_bounds=target_hardware.chips_per_host_bounds,
-        num_slices=config.compile_topology_num_slices,
-        wrap=target_hardware.wrap,
-    ).devices
-  if config.shard_mode == ShardMode.EXPLICIT:
-    jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
-  topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
-  mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
-  topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
-  return topology_mesh
-
-
-def get_shaped_inputs(topology_mesh, config):
-  """Get shaped abstractions of inputs to train_step: state, batch and rng"""
-  # Construct the model and optimizer to get shaped versions of the state
-  quant = quantizations.configure_quantization(config)
-  model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
-  # The learning_rate_schedule is baked into the compiled object.
-  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  # pass in model for muon
-  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
-
-  # Shaped RNG keys
-  _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
-  shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
-
-  # Shaped state
-  abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
-      model, tx, config, example_rng, topology_mesh
-  )
-
-  # unsharded logical annotations
-  logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
-
-  # Shaped batch
-  shaped_batch = maxtext_utils.get_shaped_batch(config)
-
-  shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
-  shaped_train_kwargs = {}
-  return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model
-
-
-def jit_and_compile(
-    func,
-    func_input_args,
-    func_input_kwargs,
-    mesh,
-    in_shardings,
-    out_shardings,
-    static_argnums,
-    donate_argnums,
-    config,
-    logical_axis_rules,
-):
-  """Jit, lower, and compile func."""
-  with jax.set_mesh(mesh), logical_axis_rules:
-    jitted = jax.jit(
-        func,
-        in_shardings=in_shardings,
-        out_shardings=out_shardings,
-        static_argnums=static_argnums,
-        donate_argnums=donate_argnums,
-    )
-    maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
-    lowered = jitted.lower(*func_input_args, **func_input_kwargs)
-  compiled = lowered.compile()
-  return compiled
-
-
-def save_compiled(compiled, save_name):
-  """Serialize and save the compiled function."""
-  serialized, _, _ = serialize(compiled)
-  with open(save_name, "wb") as f:
-    pickle.dump(serialized, f)
-
-
-def is_oom(argv: Sequence[str]) -> bool:
-  """Function returns a boolean indicating whether OOM happens"""
-  # Parse and validate configuration
-  config = pyconfig.initialize(argv)
-  validate_config(config)
-
-  # Create target mesh
-  topology_mesh = get_topology_mesh(config)
-
-  # Print system information after building the compile topology to avoid
-  # prematurely initializing the backend.
-  max_utils.print_system_information()
-
-  # Get shaped inputs
-  (
-      shaped_train_args,
-      shaped_train_kwargs,
-      state_mesh_shardings,
-      _,
-      model,
-  ) = get_shaped_inputs(topology_mesh, config)
-
-  # Get data sharding
-  data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
-
-  # Get function to compile and shardings
-  func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-      maxtext_utils.get_functional_train_with_signature(
-          train.train_step, data_sharding, state_mesh_shardings, model, config
-      )
-  )
+OLD_MODULE_PATH = "MaxText.train_comile"
+NEW_MODULE_PATH = "maxtext.trainers.pre_train.train_compile"

+if __name__ == "__main__":
   try:
-    _ = jit_and_compile(
-        func_to_compile,
-        shaped_train_args,
-        shaped_train_kwargs,
-        topology_mesh,
-        in_shard,
-        out_shard,
-        static_argnums,
-        donate_argnums,
-        config,
-        nn_partitioning.axis_rules(config.logical_axis_rules),
-    )
-    return False
-  except Exception as e:
-    # return true if OOM error happens
-    # OOM error looks like
-    # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
-    # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
-    message = str(e).lower()
-    if "resource_exhausted" in message or "hbm" in message:
-      return True
+    logging.set_verbosity(logging.INFO)
+    _new_module = importlib.import_module(NEW_MODULE_PATH)
+    if hasattr(_new_module, "main"):
+      max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
+      _new_module.main(sys.argv)
+  except ImportError as e:
+    max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n")
     raise e
-
-
-def main(argv: Sequence[str]) -> None:
-  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  os.environ["LIBTPU_INIT_ARGS"] = (
-      os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
-  )
-  print("Starting train_compile.py...", flush=True)
-
-  # Parse and validate configuration
-  config = pyconfig.initialize(argv)
-  validate_config(config)
-
-  # Create target mesh
-  topology_mesh = get_topology_mesh(config)
-
-  # Print system information after building the compile topology to avoid
-  # prematurely initializing the backend.
-  max_utils.print_system_information()
-
-  # Get shaped inputs
-  (
-      shaped_train_args,
-      shaped_train_kwargs,
-      state_mesh_shardings,
-      logical_annotations,
-      model,
-  ) = get_shaped_inputs(topology_mesh, config)
-
-  # Get data sharding
-  data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
-  if config.enable_diloco:
-    # Build abstract DiLoCo state and shardings for AOT compilation
-    abstract_state = shaped_train_args[0]
-    diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
-        config, abstract_state, state_mesh_shardings, topology_mesh
-    )
-    shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2])
-
-    # Wrap train_step with diloco
-    train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None)
-    train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)
-
-    # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
-    func_to_compile = train_step_fn
-    func_to_compile.__name__ = "train_step"
-    in_shard = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
-    out_shard = (state_mesh_shardings, None)  # State, metrics
-    static_argnums = ()
-    donate_argnums = 0
-  else:
-    # Get function to compile and shardings
-    func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-        maxtext_utils.get_functional_train_with_signature(
-            train.train_step, data_sharding, state_mesh_shardings, model, config
-        )
-    )
-
-  # print weights sharding info under debug sharding mode
-  if config.debug_sharding:
-    max_utils.print_non_trivial_mesh_axis(topology_mesh)
-    maxtext_utils.print_shardings_params(
-        shaped_train_args[0].params,
-        state_mesh_shardings.params,
-        topology_mesh,
-        logical_annotations.params,
-    )
-
-  # Compile
-  print("Jitting and compiling train step...", flush=True)
-  compiled = jit_and_compile(
-      func_to_compile,
-      shaped_train_args,
-      shaped_train_kwargs,
-      topology_mesh,
-      in_shard,
-      out_shard,
-      static_argnums,
-      donate_argnums,
-      config,
-      nn_partitioning.axis_rules(config.logical_axis_rules),
-  )
-  print("Jitting and compilation complete!", flush=True)
-
-  # Serialize and save the compiled object
-  if config.compiled_trainstep_file != "":
-    print("Saving compiled object...")
-    save_compiled(compiled, config.compiled_trainstep_file)
-    print(f"Successfully saved compiled object as {config.compiled_trainstep_file}")
-  print("Finished train_compile.py successfully!", flush=True)
-  print(f"Cost analysis: {compiled.cost_analysis()}")
-  print(f"Memory analysis: {compiled.memory_analysis()}")
-
-  # Dump HLO if requested
-  if config.dump_hlo:
-    gcs_utils.upload_dump(
-        config.dump_hlo_local_dir,
-        config.dump_hlo_gcs_dir,
-        module_name=config.dump_hlo_module_name,
-        delete_local_after=config.dump_hlo_delete_local_after,
-        all_host_upload=config.dump_hlo_upload_all,
-    )
-
-
-if __name__ == "__main__":
-  app.run(main)
```
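With this change, `src/MaxText/train_compile.py` is reduced to a thin deprecation shim: the implementation now lives in `maxtext.trainers.pre_train.train_compile`, and the old path only forwards to it with a warning. Below is a minimal, generic sketch of that pattern; the module paths are placeholders, and the standard `warnings`/`print` calls stand in for MaxText's `max_logging` helpers.

```python
"""Generic sketch of a deprecation shim: run the old entry point, delegate to the new module."""

import importlib
import sys
import warnings

OLD_MODULE_PATH = "oldpkg.tool"          # placeholder legacy path
NEW_MODULE_PATH = "newpkg.subpkg.tool"   # placeholder new location

if __name__ == "__main__":
  warnings.warn(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.", DeprecationWarning)
  try:
    new_module = importlib.import_module(NEW_MODULE_PATH)
  except ImportError:
    print(f"Shim could not find target module: '{NEW_MODULE_PATH}'", file=sys.stderr)
    raise
  # Delegate to the relocated entry point, passing the CLI arguments through.
  if hasattr(new_module, "main"):
    new_module.main(sys.argv)
```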
