|
36 | 36 | from MaxText import accelerator_to_spec_map, max_utils |
37 | 37 | from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode |
38 | 38 | from MaxText.globals import MAXTEXT_ASSETS_ROOT |
| 39 | +from MaxText.utils import gcs_utils |
39 | 40 |
|
40 | 41 | logger = logging.getLogger(__name__) |
41 | 42 |
|
@@ -1787,6 +1788,29 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig": |
1787 | 1788 | if self.final_logits_soft_cap == 0.0: |
1788 | 1789 | self.final_logits_soft_cap = None |
1789 | 1790 |
|
| 1791 | + # This must be invoked before initializing the backend |
| 1792 | + # pylint: disable=access-member-before-definition |
# validate_and_set_hlo_dump_defaults: nested helper that closes over `self` (no parameters, no
# explicit return value). It checks the XLA_FLAGS / dump_hlo_xla_flags combination, fills in
# defaults for `dump_hlo_xla_flags` and `dump_hlo_gcs_dir`, and exports the flags through the
# XLA_FLAGS environment variable when that variable is empty or unset.
# Raises: ValueError when both XLA_FLAGS and `self.dump_hlo_xla_flags` are non-empty.
# NOTE(review): nothing visible in this hunk gates the helper on HLO dumping being enabled, and
# the call site below invokes it unconditionally -- confirm that is intended.
# NOTE(review): the original indentation is not visible in this diff view, so the exact nesting
# of the branches below (in particular the module-name check) must be confirmed against the file.
| 1793 | + def validate_and_set_hlo_dump_defaults(): |
# The two sources of XLA flags are treated as mutually exclusive: reject the case where both are set.
| 1794 | + if os.environ.get("XLA_FLAGS") and self.dump_hlo_xla_flags: |
| 1795 | + raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.") |
# Neither source supplied flags: build a default that dumps to `dump_hlo_local_dir`.
| 1796 | + if not os.environ.get("XLA_FLAGS") and not self.dump_hlo_xla_flags: |
| 1797 | + self.dump_hlo_xla_flags = f"--xla_dump_to={self.dump_hlo_local_dir} --xla_dump_large_constants" |
# When a module name is configured, append it as the --xla_dump_hlo_module_re value.
| 1798 | + if self.dump_hlo_local_module_name: |
| 1799 | + self.dump_hlo_xla_flags = ( |
| 1800 | + f"{self.dump_hlo_xla_flags} --xla_dump_hlo_module_re={self.dump_hlo_local_module_name}" |
| 1801 | + ) |
# Default upload location is <base_output_directory>/<run_name>/xla_dump.
# NOTE(review): only a user-supplied dir goes through gcs_utils.add_trailing_slash (presumably
# appends a trailing "/" -- verify); the computed default is stored exactly as os.path.join returns it.
| 1802 | + if not self.dump_hlo_gcs_dir: |
| 1803 | + self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump") |
| 1804 | + else: |
| 1805 | + self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir) |
# Export the flags through the process environment only when XLA_FLAGS is empty or unset,
# so an existing non-empty value is never overwritten.
| 1806 | + if not os.environ.get("XLA_FLAGS"): |
| 1807 | + os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags |
| 1808 | + |
| 1809 | + # pylint: enable=access-member-before-definition |
| 1810 | + |
| 1811 | + # Validate and initialize the HLO-dump-related configs |
| 1812 | + validate_and_set_hlo_dump_defaults() |
| 1813 | + |
1790 | 1814 | # D. CALCULATE MODEL DIMENSIONS from global_parameter_scale |
1791 | 1815 | # This allows scaling the model size up or down easily with a single power-of-two factor. |
1792 | 1816 | emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale) |
|
0 commit comments