|
1 | | -# Copyright 2023–2025 Google LLC |
| 1 | +# Copyright 2023–2026 Google LLC |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -""" |
16 | | -Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step |
17 | | -Generates shaped versions of state and data without ever constructing them, so it's possible
18 | | -to compile with target hardware (e.g. hundreds/thousands of chips), without using the hardware. |
19 | | -This helpfully detects if your configuration would run into memory problems (OOM) on the target hardware, |
20 | | -before having to use the target hardware - you will see the same OOM error message during this compilation |
21 | | -as you would on the target hardware. |
22 | | -""" |
| 15 | +"""Shim for pre-training trainers in `src/maxtext/trainers/pre_train`.""" |
23 | 16 |
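Note: the removed docstring describes JAX's ahead-of-time (AOT) compilation flow, where shaped abstractions stand in for real arrays so an OOM surfaces before any target hardware is touched. A minimal sketch of that underlying pattern in stock JAX (the function and shapes below are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# A ShapeDtypeStruct stands in for a real array; no memory is allocated.
shaped_x = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)

# lower() traces against the shape only; compile() invokes XLA, which is
# where an unfittable configuration fails with RESOURCE_EXHAUSTED.
compiled = jax.jit(f).lower(shaped_x).compile()
print(compiled.memory_analysis())
```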
|
24 | | -import functools |
25 | | -import os |
26 | | -import pickle |
27 | | -from typing import Sequence |
| 17 | +import sys |
| 18 | +import importlib |
28 | 19 |
|
29 | | -from absl import app |
30 | | -from flax.linen import partitioning as nn_partitioning |
31 | | -import jax |
32 | | -from jax.experimental.serialize_executable import serialize |
33 | | -from jax.experimental.topologies import get_topology_desc |
34 | | -from jax.sharding import AxisType, Mesh |
35 | | -from MaxText import accelerator_to_spec_map |
36 | | -from MaxText import optimizers |
37 | | -from MaxText import pyconfig |
38 | | -from MaxText import sharding |
39 | | -from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode |
40 | | -from maxtext.layers import quantizations |
41 | | -from maxtext.models import models |
42 | | -from maxtext.trainers.diloco import diloco |
43 | | -from maxtext.trainers.pre_train import train |
44 | | -from maxtext.utils import gcs_utils |
45 | | -from maxtext.utils import max_utils |
46 | | -from maxtext.utils import maxtext_utils |
| 20 | +from absl import logging |
47 | 21 |
|
48 | | -# pylint: disable=too-many-positional-arguments |
| 22 | +from maxtext.utils import max_logging |
49 | 23 |
|
50 | | -Transformer = models.transformer_as_linen |
51 | | - |
52 | | - |
53 | | -def validate_config(config): |
54 | | - """Validates the config is is setup correctly to compile, returning a useful error message if not.""" |
55 | | - assert ( |
56 | | - config.compile_topology != "" |
57 | | - ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256" |
58 | | - assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" |
59 | | - |
60 | | - |
61 | | -def get_topology_mesh(config): |
62 | | - """Get the target hardware devices, and create configured mesh with them""" |
63 | | - target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) |
64 | | - if target_hardware.platform == "gpu": |
65 | | - # Disable sharded autotuning. This is an optimization to distribute |
66 | | - # autotuning across the fleet, but can cause hangs with AoT compilation. |
67 | | - os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" |
68 | | - jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) |
69 | | - topology_devices = jax.devices() |
70 | | - else: |
71 | | - topology_devices = get_topology_desc( |
72 | | - platform=target_hardware.platform, |
73 | | - topology_name=target_hardware.topology_name, |
74 | | - chip_config_name=target_hardware.chip_config_name, |
75 | | - chips_per_host_bounds=target_hardware.chips_per_host_bounds, |
76 | | - num_slices=config.compile_topology_num_slices, |
77 | | - wrap=target_hardware.wrap, |
78 | | - ).devices |
79 | | - if config.shard_mode == ShardMode.EXPLICIT: |
80 | | - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) |
81 | | - topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) |
82 | | - mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto |
83 | | - topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes)) |
84 | | - return topology_mesh |
85 | | - |
86 | | - |
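For context, a hedged sketch of calling `get_topology_desc` directly with the same keyword arguments the removed function forwards; every concrete value below is a hypothetical placeholder for what `accelerator_to_spec_map.get_system_characteristics` would supply:

```python
from jax.experimental.topologies import get_topology_desc

topology_devices = get_topology_desc(
    platform="tpu",                   # target_hardware.platform
    topology_name="v5e:16x16",        # hypothetical topology string
    chip_config_name="default",       # hypothetical
    chips_per_host_bounds=(2, 2, 1),  # hypothetical
    num_slices=1,
    wrap=(False, False, False),       # hypothetical
).devices
```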
87 | | -def get_shaped_inputs(topology_mesh, config): |
88 | | - """Get shaped abstractions of inputs to train_step: state, batch and rng""" |
89 | | - # Construct the model and optimizer to get shaped versions of the state |
90 | | - quant = quantizations.configure_quantization(config) |
91 | | - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) |
92 | | - # The learning_rate_schedule is baked into the compiled object. |
93 | | - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) |
94 | | - # pass in model for muon |
95 | | - tx = optimizers.get_optimizer(config, learning_rate_schedule, model) |
96 | | - |
97 | | - # Shaped RNG keys |
98 | | - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) |
99 | | - shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) |
100 | | - |
101 | | - # Shaped state |
102 | | - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( |
103 | | - model, tx, config, example_rng, topology_mesh |
104 | | - ) |
105 | | - |
106 | | - # unsharded logical annotations |
107 | | - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) |
108 | | - |
109 | | - # Shaped batch |
110 | | - shaped_batch = maxtext_utils.get_shaped_batch(config) |
111 | | - |
112 | | - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) |
113 | | - shaped_train_kwargs = {} |
114 | | - return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model |
115 | | - |
116 | | - |
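The "shaped state" built above relies on shape-only tracing. A self-contained sketch of the mechanism via `jax.eval_shape`; the initializer and sizes are illustrative, not MaxText's:

```python
import jax
import jax.numpy as jnp

def init_params(rng):
  # Would allocate ~256 MB eagerly; under eval_shape it allocates nothing.
  return {"w": jax.random.normal(rng, (8192, 8192)), "b": jnp.zeros((8192,))}

abstract_params = jax.eval_shape(init_params, jax.random.PRNGKey(0))
# {'b': ShapeDtypeStruct(shape=(8192,), dtype=float32),
#  'w': ShapeDtypeStruct(shape=(8192, 8192), dtype=float32)}
```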
117 | | -def jit_and_compile( |
118 | | - func, |
119 | | - func_input_args, |
120 | | - func_input_kwargs, |
121 | | - mesh, |
122 | | - in_shardings, |
123 | | - out_shardings, |
124 | | - static_argnums, |
125 | | - donate_argnums, |
126 | | - config, |
127 | | - logical_axis_rules, |
128 | | -): |
129 | | - """Jit, lower, and compile func.""" |
130 | | - with jax.set_mesh(mesh), logical_axis_rules: |
131 | | - jitted = jax.jit( |
132 | | - func, |
133 | | - in_shardings=in_shardings, |
134 | | - out_shardings=out_shardings, |
135 | | - static_argnums=static_argnums, |
136 | | - donate_argnums=donate_argnums, |
137 | | - ) |
138 | | - maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args) |
139 | | - lowered = jitted.lower(*func_input_args, **func_input_kwargs) |
140 | | - compiled = lowered.compile() |
141 | | - return compiled |
142 | | - |
143 | | - |
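A miniature of the jit/lower/compile flow above with explicit shardings, assuming only a one-axis mesh over whatever devices are visible (axis name and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
shard = NamedSharding(mesh, PartitionSpec("data"))

jitted = jax.jit(lambda x: x * 2, in_shardings=shard, out_shardings=shard)
batch = jax.ShapeDtypeStruct((len(jax.devices()) * 4, 8), jnp.float32)
compiled = jitted.lower(batch).compile()
```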
144 | | -def save_compiled(compiled, save_name): |
145 | | - """Serialize and save the compiled function.""" |
146 | | - serialized, _, _ = serialize(compiled) |
147 | | - with open(save_name, "wb") as f: |
148 | | - pickle.dump(serialized, f) |
149 | | - |
150 | | - |
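`save_compiled` above pickles only the serialized bytes and discards the in/out trees that `serialize` also returns. A hedged sketch of the matching load path, assuming the caller can reconstruct equivalent trees (any loader must):

```python
import pickle
from jax.experimental.serialize_executable import deserialize_and_load

def load_compiled(save_name, in_tree, out_tree):
  """Inverse of save_compiled; the trees must match the original jit."""
  with open(save_name, "rb") as f:
    serialized = pickle.load(f)
  return deserialize_and_load(serialized, in_tree, out_tree)
```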
151 | | -def is_oom(argv: Sequence[str]) -> bool: |
152 | | - """Function returns a boolean indicating whether OOM happens""" |
153 | | - # Parse and validate configuration |
154 | | - config = pyconfig.initialize(argv) |
155 | | - validate_config(config) |
156 | | - |
157 | | - # Create target mesh |
158 | | - topology_mesh = get_topology_mesh(config) |
159 | | - |
160 | | - # Print system information after building the compile topology to avoid |
161 | | - # prematurely initializing the backend. |
162 | | - max_utils.print_system_information() |
163 | | - |
164 | | - # Get shaped inputs |
165 | | - ( |
166 | | - shaped_train_args, |
167 | | - shaped_train_kwargs, |
168 | | - state_mesh_shardings, |
169 | | - _, |
170 | | - model, |
171 | | - ) = get_shaped_inputs(topology_mesh, config) |
172 | | - |
173 | | - # Get data sharding |
174 | | - data_sharding = sharding.get_input_data_sharding(config, topology_mesh) |
175 | | - |
176 | | - # Get function to compile and shardings |
177 | | - func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( |
178 | | - maxtext_utils.get_functional_train_with_signature( |
179 | | - train.train_step, data_sharding, state_mesh_shardings, model, config |
180 | | - ) |
181 | | - ) |
| 24 | +OLD_MODULE_PATH = "MaxText.train_compile"
| 25 | +NEW_MODULE_PATH = "maxtext.trainers.pre_train.train_compile" |
182 | 26 |
|
| 27 | +if __name__ == "__main__": |
183 | 28 | try: |
184 | | - _ = jit_and_compile( |
185 | | - func_to_compile, |
186 | | - shaped_train_args, |
187 | | - shaped_train_kwargs, |
188 | | - topology_mesh, |
189 | | - in_shard, |
190 | | - out_shard, |
191 | | - static_argnums, |
192 | | - donate_argnums, |
193 | | - config, |
194 | | - nn_partitioning.axis_rules(config.logical_axis_rules), |
195 | | - ) |
196 | | - return False |
197 | | - except Exception as e: |
198 | | - # return true if OOM error happens |
199 | | - # OOM error looks like |
200 | | - # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ... |
201 | | - # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ... |
202 | | - message = str(e).lower() |
203 | | - if "resource_exhausted" in message or "hbm" in message: |
204 | | - return True |
| 29 | + logging.set_verbosity(logging.INFO) |
| 30 | + _new_module = importlib.import_module(NEW_MODULE_PATH) |
| 31 | + if hasattr(_new_module, "main"): |
| 32 | + max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n") |
| 33 | + _new_module.main(sys.argv) |
| 34 | + except ImportError as e: |
| 35 | + max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n") |
205 | 36 | raise e |
206 | | - |
207 | | - |
208 | | -def main(argv: Sequence[str]) -> None: |
209 | | - jax.config.update("jax_default_prng_impl", "unsafe_rbg") |
210 | | - os.environ["LIBTPU_INIT_ARGS"] = ( |
211 | | - os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" |
212 | | - ) |
213 | | - print("Starting train_compile.py...", flush=True) |
214 | | - |
215 | | - # Parse and validate configuration |
216 | | - config = pyconfig.initialize(argv) |
217 | | - validate_config(config) |
218 | | - |
219 | | - # Create target mesh |
220 | | - topology_mesh = get_topology_mesh(config) |
221 | | - |
222 | | - # Print system information after building the compile topology to avoid |
223 | | - # prematurely initializing the backend. |
224 | | - max_utils.print_system_information() |
225 | | - |
226 | | - # Get shaped inputs |
227 | | - ( |
228 | | - shaped_train_args, |
229 | | - shaped_train_kwargs, |
230 | | - state_mesh_shardings, |
231 | | - logical_annotations, |
232 | | - model, |
233 | | - ) = get_shaped_inputs(topology_mesh, config) |
234 | | - |
235 | | - # Get data sharding |
236 | | - data_sharding = sharding.get_input_data_sharding(config, topology_mesh) |
237 | | - if config.enable_diloco: |
238 | | - # Build abstract DiLoCo state and shardings for AOT compilation |
239 | | - abstract_state = shaped_train_args[0] |
240 | | - diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( |
241 | | - config, abstract_state, state_mesh_shardings, topology_mesh |
242 | | - ) |
243 | | - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) |
244 | | - |
245 | | - # Wrap train_step with diloco |
246 | | - train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) |
247 | | - train_step_fn = diloco.build_diloco_train_step(config, train_step_partial) |
248 | | - |
249 | | - # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng) |
250 | | - func_to_compile = train_step_fn |
251 | | - func_to_compile.__name__ = "train_step" |
252 | | - in_shard = (state_mesh_shardings, data_sharding, None) # State, batch, rng |
253 | | - out_shard = (state_mesh_shardings, None) # State, metrics |
254 | | - static_argnums = () |
255 | | - donate_argnums = 0 |
256 | | - else: |
257 | | - # Get function to compile and shardings |
258 | | - func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( |
259 | | - maxtext_utils.get_functional_train_with_signature( |
260 | | - train.train_step, data_sharding, state_mesh_shardings, model, config |
261 | | - ) |
262 | | - ) |
263 | | - |
264 | | - # print weights sharding info under debug sharding mode |
265 | | - if config.debug_sharding: |
266 | | - max_utils.print_non_trivial_mesh_axis(topology_mesh) |
267 | | - maxtext_utils.print_shardings_params( |
268 | | - shaped_train_args[0].params, |
269 | | - state_mesh_shardings.params, |
270 | | - topology_mesh, |
271 | | - logical_annotations.params, |
272 | | - ) |
273 | | - |
274 | | - # Compile |
275 | | - print("Jitting and compiling train step...", flush=True) |
276 | | - compiled = jit_and_compile( |
277 | | - func_to_compile, |
278 | | - shaped_train_args, |
279 | | - shaped_train_kwargs, |
280 | | - topology_mesh, |
281 | | - in_shard, |
282 | | - out_shard, |
283 | | - static_argnums, |
284 | | - donate_argnums, |
285 | | - config, |
286 | | - nn_partitioning.axis_rules(config.logical_axis_rules), |
287 | | - ) |
288 | | - print("Jitting and compilation complete!", flush=True) |
289 | | - |
290 | | - # Serialize and save the compiled object |
291 | | - if config.compiled_trainstep_file != "": |
292 | | - print("Saving compiled object...") |
293 | | - save_compiled(compiled, config.compiled_trainstep_file) |
294 | | - print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") |
295 | | - print("Finished train_compile.py successfully!", flush=True) |
296 | | - print(f"Cost analysis: {compiled.cost_analysis()}") |
297 | | - print(f"Memory analysis: {compiled.memory_analysis()}") |
298 | | - |
299 | | - # Dump HLO if requested |
300 | | - if config.dump_hlo: |
301 | | - gcs_utils.upload_dump( |
302 | | - config.dump_hlo_local_dir, |
303 | | - config.dump_hlo_gcs_dir, |
304 | | - module_name=config.dump_hlo_module_name, |
305 | | - delete_local_after=config.dump_hlo_delete_local_after, |
306 | | - all_host_upload=config.dump_hlo_upload_all, |
307 | | - ) |
308 | | - |
309 | | - |
310 | | -if __name__ == "__main__": |
311 | | - app.run(main) |
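The replacement file is an execution shim: run as a script, it forwards to the new module's `main`. A hypothetical alternative (sketch only, not what this change does) is an import-time shim that keeps `import MaxText.train_compile` working by aliasing the module in `sys.modules`:

```python
import importlib
import sys
import warnings

_NEW_MODULE_PATH = "maxtext.trainers.pre_train.train_compile"

warnings.warn(
    f"this module moved to '{_NEW_MODULE_PATH}'",
    DeprecationWarning,
    stacklevel=2,
)
# Replace the old module object with the new one at import time.
sys.modules[__name__] = importlib.import_module(_NEW_MODULE_PATH)
```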