Commit cfd7786

Merge pull request #2977 from AI-Hypercomputer:chengnuojin-fix-hlo
PiperOrigin-RevId: 859148792
2 parents 1137c42 + a2ccc4a commit cfd7786

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

src/MaxText/configs/types.py

@@ -36,6 +36,7 @@
 from MaxText import accelerator_to_spec_map, max_utils
 from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
 from MaxText.globals import MAXTEXT_ASSETS_ROOT
+from MaxText.utils import gcs_utils
 
 logger = logging.getLogger(__name__)
 
@@ -1821,6 +1822,29 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
     if self.final_logits_soft_cap == 0.0:
       self.final_logits_soft_cap = None
 
+    # This must be invoked before initializing the backend
+    # pylint: disable=access-member-before-definition
+    def validate_and_set_hlo_dump_defaults():
+      if os.environ.get("XLA_FLAGS") and self.dump_hlo_xla_flags:
+        raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.")
+      if not os.environ.get("XLA_FLAGS") and not self.dump_hlo_xla_flags:
+        self.dump_hlo_xla_flags = f"--xla_dump_to={self.dump_hlo_local_dir} --xla_dump_large_constants"
+        if self.dump_hlo_local_module_name:
+          self.dump_hlo_xla_flags = (
+              f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}"
+          )
+      if not self.dump_hlo_gcs_dir:
+        self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump")
+      else:
+        self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir)
+      if not os.environ.get("XLA_FLAGS"):
+        os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags
+
+    # pylint: enable=access-member-before-definition
+
+    # Validate and initiate hlo dump related configs
+    validate_and_set_hlo_dump_defaults()
+
     # D. CALCULATE MODEL DIMENSIONS from global_parameter_scale
     # This allows scaling the model size up or down easily with a single power-of-two factor.
     emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale)
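
Below is a minimal, self-contained sketch (not MaxText code) of the logic this commit adds: XLA_FLAGS and dump_hlo_xla_flags are treated as mutually exclusive, and when neither is set a default --xla_dump_to flag string is derived from the local dump directory and exported before the backend initializes. The HloDumpConfig dataclass and build_hlo_dump_flags helper are hypothetical names used only for illustration; only the flag strings and config field names come from the diff above.

import os
from dataclasses import dataclass


@dataclass
class HloDumpConfig:
  # Hypothetical stand-in for the relevant MaxTextConfig fields.
  dump_hlo_xla_flags: str = ""
  dump_hlo_local_dir: str = "/tmp/xla_dump"
  dump_hlo_local_module_name: str = ""


def build_hlo_dump_flags(cfg: HloDumpConfig) -> str:
  env_flags = os.environ.get("XLA_FLAGS")
  # The two sources of XLA flags are mutually exclusive.
  if env_flags and cfg.dump_hlo_xla_flags:
    raise ValueError("Set either XLA_FLAGS or dump_hlo_xla_flags, but not both.")
  # If neither is set, derive a default dump flag string from the local dump directory,
  # optionally restricted to a single HLO module by name.
  if not env_flags and not cfg.dump_hlo_xla_flags:
    cfg.dump_hlo_xla_flags = f"--xla_dump_to={cfg.dump_hlo_local_dir} --xla_dump_large_constants"
    if cfg.dump_hlo_local_module_name:
      cfg.dump_hlo_xla_flags += f" --xla_dump_hlo_module_re={cfg.dump_hlo_local_module_name}"
  if not env_flags:
    # XLA reads XLA_FLAGS when the backend initializes, so this must run before JAX starts.
    os.environ["XLA_FLAGS"] = cfg.dump_hlo_xla_flags
  return cfg.dump_hlo_xla_flags


if __name__ == "__main__":
  print(build_hlo_dump_flags(HloDumpConfig(dump_hlo_local_module_name="jit_train_step")))

As in the diff, the useful property is ordering: the flags are validated and exported as environment state before any backend initialization, since XLA only consults XLA_FLAGS once at startup.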
