|
17 | 17 |
|
18 | 18 | import functools |
19 | 19 | import pickle |
| 20 | +import os |
20 | 21 |
|
21 | 22 | from flax import linen as nn |
22 | 23 | from flax.linen import partitioning as nn_partitioning |
|
40 | 41 | from MaxText import multimodal_utils |
41 | 42 | from MaxText import sharding |
42 | 43 | from MaxText.configs import types |
| 44 | +from MaxText.utils import gcs_utils |
43 | 45 | from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE |
44 | 46 | from MaxText.inference.page_manager import PageState |
45 | 47 | from maxtext.common import checkpointing |
@@ -1234,3 +1236,33 @@ def print_shardings_params(params, params_sharding, mesh): |
1234 | 1236 | shape = jax.typeof(leaf_val) |
1235 | 1237 | pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) |
1236 | 1238 | max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}") |
| 1239 | + |
| 1240 | + |
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
  """Trace the train step to a jaxpr, write it to a local file and optionally upload it to GCS.

  Does nothing unless `config.dump_jaxpr` is truthy. The upload step only runs
  when `config.dump_jaxpr_gcs_dir` is set.

  Args:
    config: Run configuration. Reads `dump_jaxpr`, `dump_jaxpr_local_dir`,
      `dump_jaxpr_gcs_dir` and `dump_jaxpr_delete_local_after`.
    p_train_step: The train step callable to trace (the JIT-decorated function).
    train_step_inputs: Positional arguments forwarded to the traced train step.
  """
  if not config.dump_jaxpr:
    return
  max_logging.log("Tracing train_step to jaxpr...")

  # We use the p_train_step (the JIT-decorated function)
  p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)

  local_filename = "train_step.jaxpr"
  local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)

  os.makedirs(config.dump_jaxpr_local_dir, exist_ok=True)

  # Explicit encoding so the dump does not depend on the host's locale settings.
  with open(local_path, "w", encoding="utf-8") as f:
    f.write(str(p_train_jaxpr))

  max_logging.log(f"Jaxpr dumped locally to {local_path}")

  if config.dump_jaxpr_gcs_dir:
    gcs_utils.upload_dump(
        config.dump_jaxpr_local_dir,
        config.dump_jaxpr_gcs_dir,
        module_name=local_filename,
        # Whether the local copy is removed after upload is decided by config.
        delete_local_after=config.dump_jaxpr_delete_local_after,
        all_host_upload=False,  # Only upload from lead host (Host 0)
    )
0 commit comments