Skip to content

Commit 12fe4ce

Browse files
Merge pull request #3303 from abhinavgoel95:fix/gate-tpu-power-profiling-events
PiperOrigin-RevId: 878616815
2 parents 7fadf54 + ea55a0d commit 12fe4ce

3 files changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxtext/common/profiler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def __init__(self, config, offset_step=0):
51 51
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
52 52

53 53
self.profiling_options = jax.profiler.ProfileOptions()
54-
if self.mode == "xplane" and not self.managed_mldiagnostics:
54+
if self.mode == "xplane" and not self.managed_mldiagnostics and config.profile_power_events:
55 55
self.profiling_options.advanced_configuration = {
56 56
"tpu_power_trace_level": config.xprof_tpu_power_trace_level,
57 57
"e2e_enable_fw_throttle_event": config.xprof_e2e_enable_fw_throttle_event,

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -916,6 +916,7 @@ xprof_tpu_power_trace_level: 0
916 916
xprof_e2e_enable_fw_throttle_event: False
917 917
xprof_e2e_enable_fw_power_level_event: False
918 918
xprof_e2e_enable_fw_thermal_event: False
919+
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
919 920

920 921
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
921 922
debug_sharding: False # Prints model weights sharding info

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1383,6 +1383,7 @@ class Profiling(BaseModel):
1383 1383
xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.")
1384 1384
xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.")
1385 1385
xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.")
1386+
profile_power_events: bool = Field(False, description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.")
1386 1387

1387 1388

1388 1389
class HloDump(BaseModel):

0 commit comments

Comments
 (0)