Skip to content

Commit a2ccc4a

Browse files
committed
Fix HLO dump configuration issue
1 parent d4a259d commit a2ccc4a

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

src/MaxText/configs/types.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from MaxText import accelerator_to_spec_map, max_utils
3737
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
3838
from MaxText.globals import MAXTEXT_ASSETS_ROOT
39+
from MaxText.utils import gcs_utils
3940

4041
logger = logging.getLogger(__name__)
4142

@@ -1787,6 +1788,29 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
17871788
if self.final_logits_soft_cap == 0.0:
17881789
self.final_logits_soft_cap = None
17891790

1791+
# This must be invoked before initializing the backend, because XLA reads
# XLA_FLAGS from the process environment at backend-initialization time.
# pylint: disable=access-member-before-definition
def validate_and_set_hlo_dump_defaults():
  """Validate HLO-dump settings on `self` and fill in sensible defaults.

  NOTE(review): the nesting below is reconstructed from a whitespace-mangled
  diff view — confirm it against the committed file. Also confirm that a
  `dump_hlo`-style opt-in guard exists around (or inside) this logic
  upstream, since the call below is unconditional in the visible hunk.
  """
  env_flags = os.environ.get("XLA_FLAGS")
  # Exactly one source of XLA flags may be in effect.
  if env_flags and self.dump_hlo_xla_flags:
    raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.")
  if not env_flags and not self.dump_hlo_xla_flags:
    # Neither source was provided: synthesize default dump flags pointing at
    # the configured local directory, keeping large constants in the dump.
    self.dump_hlo_xla_flags = f"--xla_dump_to={self.dump_hlo_local_dir} --xla_dump_large_constants"
    if self.dump_hlo_local_module_name:
      # Narrow the dump to HLO modules matching the configured regex.
      self.dump_hlo_xla_flags = (
          f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}"
      )
  if self.dump_hlo_gcs_dir:
    # Normalize a user-supplied GCS destination.
    self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir)
  else:
    # Default GCS destination lives under this run's output directory.
    self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump")
  # Only publish our flags when the user did not set XLA_FLAGS themselves
  # (nothing above mutates the environment, so env_flags is still current).
  if not env_flags:
    os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags

# pylint: enable=access-member-before-definition

# Validate and initiate hlo dump related configs
validate_and_set_hlo_dump_defaults()
1813+
17901814
# D. CALCULATE MODEL DIMENSIONS from global_parameter_scale
17911815
# This allows scaling the model size up or down easily with a single power-of-two factor.
17921816
emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale)

0 commit comments

Comments (0)