|
49 | 49 | from MaxText import optimizers |
50 | 50 | from MaxText import pyconfig |
51 | 51 | from MaxText import tokenizer |
| 52 | +from MaxText.input_pipeline import input_pipeline_interface |
52 | 53 | from maxtext.utils import max_logging |
53 | 54 | from maxtext.utils import maxtext_utils |
54 | 55 | from maxtext.utils import model_creation_utils |
55 | | -from maxtext.utils import train_utils |
56 | 56 |
|
57 | 57 | # Tunix Imports |
58 | 58 | from tunix.distillation import distillation_trainer |
@@ -97,7 +97,7 @@ def get_distillation_optimizer(config, max_train_steps): |
97 | 97 | peak_value=config.learning_rate, |
98 | 98 | warmup_steps=int(config.warmup_steps_fraction * max_train_steps), |
99 | 99 | decay_steps=max_train_steps, |
100 | | - end_value=config.cosine_learning_rate_final_fraction * config.learning_rate, |
| 100 | + end_value=config.learning_rate_final_fraction * config.learning_rate, |
101 | 101 | ) |
102 | 102 |
|
103 | 103 | # 2. Define Factory (Required for inject_hyperparams) |
@@ -309,7 +309,7 @@ def _post_process_train_step(self, aux: Dict[str, jax.Array]) -> None: |
309 | 309 | class MaxTextToTunixIterator: |
310 | 310 | """Adapts the raw dictionary output of MaxText's data loader to Tunix objects. |
311 | 311 |
|
312 | | - MaxText's `train_utils.create_data_iterator` yields a dictionary. |
| 312 | + MaxText's `input_pipeline_interface.create_data_iterator` yields a dictionary. |
313 | 313 | Tunix expects an object with specific attributes (input_tokens, etc.). |
314 | 314 | """ |
315 | 315 |
|
@@ -503,7 +503,7 @@ def labels_fn(targets, **kwargs): |
503 | 503 | # We use MaxText's native create_data_iterator which creates both train and eval iterators |
504 | 504 | # based on the config parameters (dataset_type, eval_interval, etc.) |
505 | 505 | max_logging.log("Initializing Data Iterators via MaxText pipeline...") |
506 | | - raw_train_iter, raw_eval_iter = train_utils.create_data_iterator(student_config, mesh) |
| 506 | + raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) |
507 | 507 |
|
508 | 508 | train_iter = MaxTextToTunixIterator(raw_train_iter) |
509 | 509 |
|
|
0 commit comments