Skip to content

Commit 843f1f3

Browse files
Enable power collection in xprof MaxText.
This change adds configuration options to enable TPU power tracing and firmware event collection (throttle, power level, thermal) in the XPlane profiler. PiperOrigin-RevId: 867809819
1 parent fa5efd3 commit 843f1f3

4 files changed

Lines changed: 125 additions & 12 deletions

File tree

src/maxtext/common/profiler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ def __init__(self, config, offset_step=0):
5050
if config.managed_mldiagnostics:
5151
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
5252

53+
self.profiling_options = jax.profiler.ProfileOptions()
54+
if self.mode == "xplane" and not self.managed_mldiagnostics:
55+
self.profiling_options.advanced_configuration = {
56+
"tpu_power_trace_level": config.xprof_tpu_power_trace_level,
57+
"e2e_enable_fw_throttle_event": config.xprof_e2e_enable_fw_throttle_event,
58+
"e2e_enable_fw_power_level_event": config.xprof_e2e_enable_fw_power_level_event,
59+
"e2e_enable_fw_thermal_event": config.xprof_e2e_enable_fw_thermal_event,
60+
}
61+
5362
def maybe_activate_profiler(self, step, state):
5463
"""Conditionally activates the profiler based on the current step.
5564
This method checks if the current training step matches the step designated
@@ -87,7 +96,7 @@ def activate(self, blocking_object=None, optional_postfix=""):
8796
return
8897
self.libcudart.cudaProfilerStart()
8998
elif self.mode == "xplane":
90-
jax.profiler.start_trace(self.output_path)
99+
jax.profiler.start_trace(self.output_path, profiler_options=self.profiling_options)
91100

92101
def maybe_deactivate_profiler(self, step, state):
93102
"""Conditionally deactivates the profiler based on the current step.

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,12 @@ prometheus_port: 0
880880
enable_jax_profiler: False
881881
jax_profiler_port: 9999
882882

883+
# TPU power trace level for xprof. 0:POWER_TRACE_NONE, 1:POWER_TRACE_NORMAL, or 2:POWER_TRACE_SPI
884+
xprof_tpu_power_trace_level: 0
885+
xprof_e2e_enable_fw_throttle_event: False
886+
xprof_e2e_enable_fw_power_level_event: False
887+
xprof_e2e_enable_fw_thermal_event: False
888+
883889
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
884890
debug_sharding: False # Prints model weights sharding info
885891

src/maxtext/configs/types.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,36 @@
1616

1717
# pylint: disable=too-many-lines
1818

19-
from enum import Enum
20-
from math import prod
21-
from tempfile import gettempdir
22-
from typing import Any, NewType, Literal, Optional
2319
import datetime
20+
import enum
21+
from enum import Enum
2422
import logging
2523
import math
24+
from math import prod
2625
import os
26+
from tempfile import gettempdir
27+
from typing import Any, Literal, NewType, Optional
2728

2829
import jax
29-
30-
from pydantic.config import ConfigDict
31-
from pydantic.fields import Field
32-
from pydantic.functional_validators import model_validator, field_validator
33-
from pydantic.main import BaseModel
34-
from pydantic.types import PositiveInt, NonNegativeFloat, NonNegativeInt
35-
3630
from MaxText import accelerator_to_spec_map
3731
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
3832
from MaxText.globals import MAXTEXT_ASSETS_ROOT
3933
from maxtext.utils import gcs_utils
4034
from maxtext.utils import max_utils
35+
from pydantic.config import ConfigDict
36+
from pydantic.fields import Field
37+
from pydantic.functional_validators import field_validator, model_validator
38+
from pydantic.main import BaseModel
39+
from pydantic.types import NonNegativeFloat, NonNegativeInt, PositiveInt
40+
41+
42+
class XProfTPUPowerTraceMode(enum.IntEnum): # pylint: disable=invalid-name
43+
"""Enum for XProfTPUPowerTraceMode."""
44+
45+
POWER_TRACE_NONE = 0
46+
POWER_TRACE_NORMAL = 1
47+
POWER_TRACE_SPI = 2
48+
4149

4250
logger = logging.getLogger(__name__)
4351

@@ -1316,6 +1324,16 @@ class Profiling(BaseModel):
13161324
hide_profiler_step_metric: bool = Field(False, description="Whether to enable profiler step metric.")
13171325
enable_jax_profiler: bool = Field(False, description="Enable the JAX live profiler.")
13181326
jax_profiler_port: int = Field(9999, description="Port for the JAX profiler.")
1327+
xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field(
1328+
XProfTPUPowerTraceMode.POWER_TRACE_NONE,
1329+
description=(
1330+
"TPU power trace level. The value should be 0 (POWER_TRACE_NONE), 1"
1331+
" (POWER_TRACE_NORMAL), or 2 (POWER_TRACE_SPI)"
1332+
),
1333+
)
1334+
xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.")
1335+
xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.")
1336+
xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.")
13191337

13201338

13211339
class HloDump(BaseModel):

tests/unit/profiler_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,99 @@
1313
# limitations under the License.
1414

1515
"""Profiler tests."""
16+
import os
1617
import sys
1718
import unittest
19+
from unittest.mock import MagicMock, patch
1820

1921
import pytest
2022

2123
from MaxText import pyconfig
24+
from MaxText.globals import MAXTEXT_PKG_DIR
25+
from maxtext.configs import types
2226
from maxtext.common import profiler
2327
from tests.utils.test_helpers import get_test_config_path
2428

2529

2630
class ProfilerTest(unittest.TestCase):
2731
"""Test for profiler."""
2832

33+
def setUp(self):
34+
super().setUp()
35+
# Mock jax.devices() to be deterministic and avoid runtime initialization errors
36+
self.mock_devices = [MagicMock(slice_index=0) for _ in range(1)]
37+
self.jax_patcher = patch("jax.devices", return_value=self.mock_devices)
38+
self.jax_patcher.start()
39+
self.jax_process_index_patcher = patch("jax.process_index", return_value=0)
40+
self.jax_process_index_patcher.start()
41+
42+
def tearDown(self):
43+
self.jax_patcher.stop()
44+
self.jax_process_index_patcher.stop()
45+
super().tearDown()
46+
47+
@pytest.mark.tpu_only
48+
def test_profiler_options_populated_from_config(self):
49+
"""Verifies that Profiler initializes jax.profiler.ProfileOptions from config."""
50+
config = pyconfig.initialize(
51+
[sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
52+
enable_checkpointing=False,
53+
run_name="test_profiler_options",
54+
base_output_directory="/tmp",
55+
profiler="xplane",
56+
xprof_tpu_power_trace_level=1,
57+
xprof_e2e_enable_fw_throttle_event=True,
58+
xprof_e2e_enable_fw_power_level_event=True,
59+
xprof_e2e_enable_fw_thermal_event=True,
60+
)
61+
62+
with patch("jax.profiler.ProfileOptions") as mock_options_cls:
63+
# Setup mock return value
64+
mock_options_instance = MagicMock()
65+
mock_options_cls.return_value = mock_options_instance
66+
67+
prof = profiler.Profiler(config)
68+
69+
# Check if ProfileOptions was instantiated
70+
mock_options_cls.assert_called_once()
71+
72+
# Verify advanced_configuration was populated
73+
expected_advanced_config = {
74+
"tpu_power_trace_level": types.XProfTPUPowerTraceMode.POWER_TRACE_NORMAL,
75+
"e2e_enable_fw_throttle_event": True,
76+
"e2e_enable_fw_power_level_event": True,
77+
"e2e_enable_fw_thermal_event": True,
78+
}
79+
self.assertEqual(prof.profiling_options.advanced_configuration, expected_advanced_config)
80+
81+
@pytest.mark.tpu_only
82+
@patch("jax.profiler.start_trace")
83+
def test_profiler_activate_passes_options(self, mock_start_trace):
84+
"""Verifies that activate() passes the profiling_options to jax.profiler.start_trace."""
85+
config = pyconfig.initialize(
86+
[sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
87+
enable_checkpointing=False,
88+
run_name="test_profiler_options",
89+
base_output_directory="/tmp",
90+
profiler="xplane",
91+
xprof_tpu_power_trace_level=2,
92+
)
93+
94+
# We need to mock ProfileOptions as well to check identity or value
95+
with patch("jax.profiler.ProfileOptions"):
96+
prof = profiler.Profiler(config)
97+
prof.activate()
98+
99+
# Verify start_trace was called with profiler_options
100+
mock_start_trace.assert_called_once()
101+
_, kwargs = mock_start_trace.call_args
102+
self.assertIn("profiler_options", kwargs)
103+
self.assertEqual(kwargs["profiler_options"], prof.profiling_options)
104+
self.assertEqual(
105+
prof.profiling_options.advanced_configuration["tpu_power_trace_level"],
106+
types.XProfTPUPowerTraceMode.POWER_TRACE_SPI,
107+
)
108+
29109
# These periodic profiler tests can run on any platform (cpu, gpu or tpu)
30110
@pytest.mark.tpu_only
31111
def test_periodic_profiler_third_period_starts(self):

0 commit comments

Comments
 (0)