Skip to content

Commit 576de45

Browse files
Merge pull request #2998 from AI-Hypercomputer:chengnuojin-dump-jaxpr
PiperOrigin-RevId: 861414981
2 parents d60808f + 9124e5c commit 576de45

8 files changed

Lines changed: 277 additions & 177 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ profile_periodically_period: -1 # If set to a positive integer, profile every pr
703703
managed_mldiagnostics: False # Whether to enable the managed diagnostics
704704
managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
705705

706-
# Dump HLO options
706+
# Dump HLO and jaxpr options
707707
dump_hlo: False
708708
dump_step: -1 # Dump modules at the given step if set to a positive integer.
709709
dump_hlo_local_dir: "/tmp/xla_dump/"
@@ -715,6 +715,10 @@ dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_d
715715
dump_hlo_upload_all: False # If true all hosts dump HLO, false only jax.process_index()==0
716716
# All hosts should have identical HLO for SPMD programs, however we have encountered some bugs
717717
# where this is not the case and it is helpful to compare HLO across hosts.
718+
dump_jaxpr: False
719+
dump_jaxpr_local_dir: "/tmp/jaxpr_dump/"
720+
dump_jaxpr_delete_local_after: True
721+
dump_jaxpr_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/jaxpr_dump
718722

719723
# When dropout is false the model is a deterministic function of the
720724
# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)

src/MaxText/configs/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,13 @@ class HloDump(BaseModel):
13021302
dump_hlo_local_module_name: str = Field("jit_train_step", description="Filter modules to save locally by this name.")
13031303
dump_hlo_xla_flags: str = Field("", description="Pass custom XLA flags for HLO dumping.")
13041304
dump_hlo_upload_all: bool = Field(False, description="Upload HLO from all hosts.")
1305+
dump_jaxpr: bool = Field(False, description="Enable jaxpr dumping.")
1306+
dump_jaxpr_local_dir: PathStr = Field(
1307+
os.path.join(gettempdir(), "jaxpr_dump", ""),
1308+
description="Local directory to dump jaxpr.",
1309+
)
1310+
dump_jaxpr_delete_local_after: bool = Field(True, description="Delete local jaxpr dump after uploading to GCS.")
1311+
dump_jaxpr_gcs_dir: PathStr = Field("", description="GCS directory to upload jaxpr dumps.")
13051312

13061313

13071314
class StackTrace(BaseModel):
@@ -1877,6 +1884,10 @@ def validate_and_set_hlo_dump_defaults():
18771884
self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump")
18781885
else:
18791886
self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir)
1887+
if not self.dump_jaxpr_gcs_dir:
1888+
self.dump_jaxpr_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "jaxpr_dump")
1889+
else:
1890+
self.dump_jaxpr_gcs_dir = gcs_utils.add_trailing_slash(self.dump_jaxpr_gcs_dir)
18801891
if not os.environ.get("XLA_FLAGS"):
18811892
os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags
18821893

src/MaxText/maxtext_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import functools
1919
import pickle
20+
import os
2021

2122
from flax import linen as nn
2223
from flax.linen import partitioning as nn_partitioning
@@ -40,6 +41,7 @@
4041
from MaxText import multimodal_utils
4142
from MaxText import sharding
4243
from MaxText.configs import types
44+
from MaxText.utils import gcs_utils
4345
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
4446
from MaxText.inference.page_manager import PageState
4547
from maxtext.common import checkpointing
@@ -1234,3 +1236,33 @@ def print_shardings_params(params, params_sharding, mesh):
12341236
shape = jax.typeof(leaf_val)
12351237
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
12361238
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
1239+
1240+
1241+
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
1242+
"""Dump jaxpr to local then upload to GCS."""
1243+
if not config.dump_jaxpr:
1244+
return
1245+
max_logging.log("Tracing train_step to jaxpr...")
1246+
1247+
# We use the p_train_step (the JIT-decorated function)
1248+
p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)
1249+
1250+
local_filename = "train_step.jaxpr"
1251+
local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)
1252+
1253+
os.makedirs(config.dump_jaxpr_local_dir, exist_ok=True)
1254+
1255+
# pylint: disable=unspecified-encoding
1256+
with open(local_path, "w") as f:
1257+
f.write(str(p_train_jaxpr))
1258+
1259+
max_logging.log(f"Jaxpr dumped locally to {local_path}")
1260+
1261+
if config.dump_jaxpr_gcs_dir:
1262+
gcs_utils.upload_dump(
1263+
config.dump_jaxpr_local_dir,
1264+
config.dump_jaxpr_gcs_dir,
1265+
module_name=local_filename,
1266+
delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging
1267+
all_host_upload=False, # Only upload from lead host (Host 0)
1268+
)

src/MaxText/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ def train_loop(config, recorder, state=None):
426426
shaped_batch = maxtext_utils.get_shaped_batch(config)
427427
if config.shard_optimizer_over_data:
428428
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
429+
maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng))
429430
if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded
430431
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
431432
compiled_stats = compiled.memory_analysis()

src/MaxText/train_compile.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def jit_and_compile(
121121
out_shardings,
122122
static_argnums,
123123
donate_argnums,
124+
config,
124125
logical_axis_rules,
125126
):
126127
"""Jit, lower, and compile func."""
@@ -132,6 +133,7 @@ def jit_and_compile(
132133
static_argnums=static_argnums,
133134
donate_argnums=donate_argnums,
134135
)
136+
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
135137
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
136138
compiled = lowered.compile()
137139
return compiled
@@ -180,6 +182,7 @@ def is_oom(argv: Sequence[str]) -> bool:
180182
out_shard,
181183
static_argnums,
182184
donate_argnums,
185+
config,
183186
nn_partitioning.axis_rules(config.logical_axis_rules),
184187
)
185188
return False
@@ -241,6 +244,7 @@ def main(argv: Sequence[str]) -> None:
241244
out_shard,
242245
static_argnums,
243246
donate_argnums,
247+
config,
244248
nn_partitioning.axis_rules(config.logical_axis_rules),
245249
)
246250
print("Jitting and compilation complete!", flush=True)

src/MaxText/utils/gcs_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def upload_dump(local_dir, target_dir, module_name=None, delete_local_after=True
7979
hostname = socket.gethostname() # Alternatively can use jax.process_id()
8080
prefix_name = os.path.join(prefix_name, hostname)
8181
target_dir = os.path.join(target_dir, hostname)
82-
max_logging.log(f"Uploading HLO Dump to {target_dir}...")
82+
max_logging.log(f"Uploading Dump to {target_dir}...")
8383
for root, _, files in os.walk(local_dir):
8484
for file in files:
8585
if module_name and module_name not in file:
@@ -91,7 +91,7 @@ def upload_dump(local_dir, target_dir, module_name=None, delete_local_after=True
9191
blob_name = os.path.join(prefix_name, relative_path)
9292
blob = bucket.blob(blob_name)
9393
blob.upload_from_filename(local_path)
94-
max_logging.log(f"HLO Dump Uploaded to {target_dir}!")
94+
max_logging.log(f"Dump Uploaded to {target_dir}!")
9595
if delete_local_after:
9696
shutil.rmtree(local_dir)
9797

tests/integration/aot_hlo_identical_test.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

0 commit comments

Comments
 (0)