Skip to content

Commit c24d321

Browse files
Merge pull request #3314 from AI-Hypercomputer:chengnuojin-move-ga
PiperOrigin-RevId: 878766160
2 parents 4af13bf + f5e4953 commit c24d321

3 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,11 @@ class Profiling(BaseModel):
13831383
xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.")
13841384
xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.")
13851385
xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.")
1386-
profile_power_events: bool = Field(False, description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.")
1386+
profile_power_events: bool = Field(
1387+
False,
1388+
description="Enable TPU-specific power/thermal profiling events."
1389+
" Defaults to False to avoid breaking GPU xplane tracing.",
1390+
)
13871391

13881392

13891393
class HloDump(BaseModel):

src/maxtext/trainers/pre_train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
5959
from maxtext.common.gcloud_stub import vertex_tensorboard_modules
6060
from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
61-
from maxtext.optimizers.gradient_accumulation import gradient_accumulation_loss_and_grad
6261
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
6362
from maxtext.utils import exceptions
6463
from maxtext.utils import gcs_utils
@@ -68,6 +67,7 @@
6867
from maxtext.utils import qk_clip_utils
6968
from maxtext.utils import sharding
7069
from maxtext.utils import train_utils
70+
from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad
7171
from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss
7272

7373
_diag_modules = _cloud_diag()
File renamed without changes.

0 commit comments

Comments (0)