|
42 | 42 | from cloud_tpu_diagnostics.configuration import diagnostic_configuration |
43 | 43 | from cloud_tpu_diagnostics.configuration import stack_trace_configuration |
44 | 44 |
|
45 | | -from MaxText import checkpointing |
46 | 45 | from MaxText import exceptions |
47 | 46 | from MaxText import max_logging |
48 | 47 | from MaxText import max_utils |
49 | 48 | from MaxText import maxtext_utils |
50 | 49 | from MaxText import train_utils |
51 | | -from MaxText import profiler |
52 | 50 | from MaxText import pyconfig |
53 | 51 | from MaxText import sharding |
54 | 52 | from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss |
55 | 53 | from MaxText.common_types import ShardMode |
56 | 54 | from MaxText.globals import EPS |
57 | | -from MaxText.metric_logger import MetricLogger |
58 | 55 | from MaxText.utils import gcs_utils |
59 | | -from MaxText.utils.goodput_utils import ( |
60 | | - GoodputEvent, |
61 | | - create_goodput_recorder, |
62 | | - maybe_monitor_goodput, |
63 | | - maybe_record_goodput, |
64 | | -) |
65 | | -from MaxText.vertex_tensorboard import VertexTensorboardManager |
66 | 56 | # Placeholder: internal |
67 | 57 |
|
68 | 58 | from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad |
69 | 59 | from MaxText.vocabulary_tiling import vocab_tiling_linen_loss |
70 | 60 | from MaxText.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn |
71 | 61 | from MaxText.train_utils import validate_train_config |
72 | | -from MaxText.metric_logger import record_activation_metrics |
73 | 62 | # pylint: disable=too-many-positional-arguments |
74 | 63 |
|
| 64 | +from maxtext.common import checkpointing, profiler |
| 65 | +from maxtext.common.goodput import ( |
| 66 | + GoodputEvent, |
| 67 | + create_goodput_recorder, |
| 68 | + maybe_monitor_goodput, |
| 69 | + maybe_record_goodput, |
| 70 | +) |
| 71 | +from maxtext.common.metric_logger import MetricLogger, record_activation_metrics |
| 72 | +from maxtext.common.vertex_tensorboard import VertexTensorboardManager |
| 73 | + |
75 | 74 |
|
76 | 75 | def get_first_step(state): |
77 | 76 | return int(state.step) |
|
0 commit comments