|
36 | 36 | from MaxText import accelerator_to_spec_map, max_utils |
37 | 37 | from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode |
38 | 38 | from MaxText.globals import MAXTEXT_ASSETS_ROOT |
| 39 | +from MaxText.utils import gcs_utils |
39 | 40 |
|
40 | 41 | logger = logging.getLogger(__name__) |
41 | 42 |
|
@@ -1821,6 +1822,29 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig": |
1821 | 1822 | if self.final_logits_soft_cap == 0.0: |
1822 | 1823 | self.final_logits_soft_cap = None |
1823 | 1824 |
|
| 1825 | + # This must be invoked before initializing the backend |
| 1826 | + # pylint: disable=access-member-before-definition |
| 1827 | + def validate_and_set_hlo_dump_defaults(): |
| 1828 | + if os.environ.get("XLA_FLAGS") and self.dump_hlo_xla_flags: |
| 1829 | + raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.") |
| 1830 | + if not os.environ.get("XLA_FLAGS") and not self.dump_hlo_xla_flags: |
| 1831 | + self.dump_hlo_xla_flags = f"--xla_dump_to={self.dump_hlo_local_dir} --xla_dump_large_constants" |
| 1832 | + if self.dump_hlo_local_module_name: |
| 1833 | + self.dump_hlo_xla_flags = ( |
| 1834 | + f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}" |
| 1835 | + ) |
| 1836 | + if not self.dump_hlo_gcs_dir: |
| 1837 | + self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump") |
| 1838 | + else: |
| 1839 | + self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir) |
| 1840 | + if not os.environ.get("XLA_FLAGS"): |
| 1841 | + os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags |
| 1842 | + |
| 1843 | + # pylint: enable=access-member-before-definition |
| 1844 | + |
| 1845 | + # Validate and initiate hlo dump related configs |
| 1846 | + validate_and_set_hlo_dump_defaults() |
| 1847 | + |
1824 | 1848 | # D. CALCULATE MODEL DIMENSIONS from global_parameter_scale |
1825 | 1849 | # This allows scaling the model size up or down easily with a single power-of-two factor. |
1826 | 1850 | emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale) |
|
0 commit comments