Skip to content

Commit cdf4e6b

Browse files
Merge pull request #3091 from AI-Hypercomputer:distill_fixes
PiperOrigin-RevId: 866110131
2 parents 695694b + 6a6203f commit cdf4e6b

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@
4949
from MaxText import optimizers
5050
from MaxText import pyconfig
5151
from MaxText import tokenizer
52+
from MaxText.input_pipeline import input_pipeline_interface
5253
from maxtext.utils import max_logging
5354
from maxtext.utils import maxtext_utils
5455
from maxtext.utils import model_creation_utils
55-
from maxtext.utils import train_utils
5656

5757
# Tunix Imports
5858
from tunix.distillation import distillation_trainer
@@ -97,7 +97,7 @@ def get_distillation_optimizer(config, max_train_steps):
9797
peak_value=config.learning_rate,
9898
warmup_steps=int(config.warmup_steps_fraction * max_train_steps),
9999
decay_steps=max_train_steps,
100-
end_value=config.cosine_learning_rate_final_fraction * config.learning_rate,
100+
end_value=config.learning_rate_final_fraction * config.learning_rate,
101101
)
102102

103103
# 2. Define Factory (Required for inject_hyperparams)
@@ -309,7 +309,7 @@ def _post_process_train_step(self, aux: Dict[str, jax.Array]) -> None:
309309
class MaxTextToTunixIterator:
310310
"""Adapts the raw dictionary output of MaxText's data loader to Tunix objects.
311311
312-
MaxText's `train_utils.create_data_iterator` yields a dictionary.
312+
MaxText's `input_pipeline_interface.create_data_iterator` yields a dictionary.
313313
Tunix expects an object with specific attributes (input_tokens, etc.).
314314
"""
315315

@@ -503,7 +503,7 @@ def labels_fn(targets, **kwargs):
503503
# We use MaxText's native create_data_iterator which creates both train and eval iterators
504504
# based on the config parameters (dataset_type, eval_interval, etc.)
505505
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
506-
raw_train_iter, raw_eval_iter = train_utils.create_data_iterator(student_config, mesh)
506+
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
507507

508508
train_iter = MaxTextToTunixIterator(raw_train_iter)
509509

tests/unit/train_distill_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def test_optimizer_factory(self):
123123
config.mu_dtype = "float32"
124124
config.gradient_clipping_threshold = 1.0
125125
config.warmup_steps_fraction = 0.1
126-
config.cosine_learning_rate_final_fraction = 0.1
126+
config.learning_rate_final_fraction = 0.1
127127

128128
# 1. Test Valid Creation
129129
opt = train_distill.get_distillation_optimizer(config, max_train_steps=100)

0 commit comments

Comments
 (0)