Commit 18f2eb6
Added train script for offline distillation training
1 parent 161f69a commit 18f2eb6

2 files changed: 137 additions & 30 deletions

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 57 additions & 1 deletion
@@ -18,6 +18,11 @@
   model structures with Tunix's training interfaces.
 """

+import os
+import pickle
+import tensorflow as tf
+from array_record.python import array_record_module
+
 from typing import Any, Iterator, Optional, List, Callable

 import flax
@@ -63,12 +68,61 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput):
   targets_position: jax.Array = None
   #: Segment IDs for packed target tokens.
   targets_segmentation: jax.Array = None
-
+  #: Top-K logits from the teacher model.
+  top_k_logits: jax.Array = None
+  top_k_indices: jax.Array = None

 # -----------------------------------------------------------------------------
 # Data Loading Adapter
 # -----------------------------------------------------------------------------

+class OfflineArrayRecordIterator:
+  """Reads the pre-generated global top-k logits file."""
+  def __init__(self, data_dir: str, epochs: int = 100):
+    # Check if the user passed a directory or a direct file path.
+    if tf.io.gfile.isdir(data_dir):
+      self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record")
+    else:
+      self.filepath = data_dir
+
+    if not tf.io.gfile.exists(self.filepath):
+      raise FileNotFoundError(f"Offline distillation file not found: {self.filepath}")
+
+    self.reader = array_record_module.ArrayRecordReader(self.filepath)
+    self.num_records = self.reader.num_records()
+    self.epochs = epochs
+    self.current_epoch = 0
+    self.record_index = 0
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    if self.record_index >= self.num_records:
+      # File exhausted: advance to the next epoch or stop.
+      self.current_epoch += 1
+      if self.current_epoch >= self.epochs:
+        raise StopIteration
+      self.record_index = 0
+      self.reader = array_record_module.ArrayRecordReader(self.filepath)
+
+    record = self.reader.read()
+    self.record_index += 1
+    data = pickle.loads(record)
+
+    # Map the arrays to match MaxText's expected dictionary.
+    batch = {
+        "inputs": data["tokens"],
+        "top_k_logits": data["top_k_logits"],
+        "top_k_indices": data["top_k_indices"],
+    }
+    for key in ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]:
+      if key in data:
+        batch[key] = data[key]
+
+    return batch

 class MaxTextToTunixIterator:
   """Adapts the raw dictionary output of MaxText's data loader to Tunix objects.
@@ -123,6 +177,8 @@ def __next__(self) -> MaxTextTrainingInput:
         targets=batch["targets"],
         targets_position=targets_position,
         targets_segmentation=targets_segmentation,
+        top_k_logits=batch.get("top_k_logits"),
+        top_k_indices=batch.get("top_k_indices"),
     )

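Note: `OfflineArrayRecordIterator` assumes each ArrayRecord entry is a pickled dict carrying at least the `tokens`, `top_k_logits`, and `top_k_indices` keys read above. As a minimal sketch (not part of this commit), a compatible file could be produced like this, with dummy arrays standing in for real teacher outputs:

    import pickle

    import numpy as np
    from array_record.python import array_record_module

    # Dummy batch; a real generation script would store teacher top-k outputs here.
    record = {
        "tokens": np.zeros((4, 1024), dtype=np.int32),              # [batch, seq]
        "top_k_logits": np.zeros((4, 1024, 64), dtype=np.float32),  # [batch, seq, k]
        "top_k_indices": np.zeros((4, 1024, 64), dtype=np.int32),   # [batch, seq, k]
    }

    # "group_size:1" keeps one record per chunk so records stay individually readable.
    writer = array_record_module.ArrayRecordWriter(
        "/tmp/teacher_top_k_global.array_record", "group_size:1"
    )
    writer.write(pickle.dumps(record))  # one record per training batch
    writer.close()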

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 80 additions & 29 deletions
@@ -32,6 +32,9 @@
 3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
    a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
 """
+import argparse
+import functools
+import sys

 from typing import Sequence, Callable
 from absl import app
@@ -303,6 +306,8 @@ def _prepare_inputs(
         targets=input_data.targets,
         targets_position=input_data.targets_position,
         targets_segmentation=input_data.targets_segmentation,
+        top_k_logits=input_data.top_k_logits,
+        top_k_indices=input_data.top_k_indices,
     )

   def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -401,7 +406,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
 # -----------------------------------------------------------------------------


-def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None:
+def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters, is_offline: bool = False, offline_data_dir: str | None = None) -> None:
   """Main distillation training loop.

   Orchestrates the loading of both student and teacher models, configures the
@@ -437,9 +442,15 @@
   _log_config_details(student_config, "Student")
   student_model = get_maxtext_model(student_config, mesh)

-  max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
-  _log_config_details(teacher_config, "Teacher")
-  teacher_model = get_maxtext_model(teacher_config, mesh)
+  # Skip teacher model loading if offline.
+  if is_offline:
+    max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
+    teacher_model = None
+  else:
+    max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
+    _log_config_details(teacher_config, "Teacher")
+    teacher_model = get_maxtext_model(teacher_config, mesh)
+    teacher_model.eval()

   # 3. Define Distillation Strategy
   def labels_fn(targets, targets_segmentation=None, **kwargs):
@@ -502,13 +513,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   )

   # 5. Data Iterators (Init BEFORE Trainer)
-  # We use MaxText's native create_data_iterator which creates both train and eval iterators
-  max_logging.log("Initializing Data Iterators via MaxText pipeline...")
-  raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
+  if is_offline:
+    max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...")
+    raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir)
+    raw_eval_iter = None
+  else:
+    max_logging.log("Initializing Data Iterators via MaxText pipeline...")
+    raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)

-  teacher_model.eval()
   student_model.train()
-
   model_bundle = ModelBundle(teacher_model, student_model)

   # 6. Initialize Trainer
@@ -526,18 +539,41 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)

   # 8. Configure Input Mapping
-  trainer = trainer.with_gen_model_input_fn(
-      lambda batch: {
-          "input_tokens": batch.input_tokens,
-          "positions": batch.positions,
-          "attention_mask": batch.input_mask,
-          "decoder_segment_ids": batch.decoder_segment_ids,
-          "targets": batch.targets,  # Passed to strategy (labels_fn)
-          "targets_position": batch.targets_position,  # Passed to strategy (labels_fn)
-          "targets_segmentation": batch.targets_segmentation,  # Passed to strategy (labels_fn)
-          "cache": None,
-      }
-  )
+  def custom_gen_model_input_fn(batch):
+    inputs_dict = {
+        "input_tokens": batch.input_tokens,
+        "positions": batch.positions,
+        "attention_mask": batch.input_mask,
+        "decoder_segment_ids": batch.decoder_segment_ids,
+        "targets": batch.targets,
+        "targets_position": batch.targets_position,
+        "targets_segmentation": batch.targets_segmentation,
+        "cache": None,
+    }
+
+    # Online mode: no precomputed teacher logits, so return the inputs unchanged.
+    if getattr(batch, "top_k_logits", None) is None:
+      return inputs_dict
+
+    # Scatter the offline top-k arrays into a dense [batch, seq, vocab] tensor
+    # filled with -10000.0, so non-top-k entries get ~0 probability after softmax.
+    dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
+    dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
+    dense_logits = jnp.put_along_axis(
+        dense_logits,
+        batch.top_k_indices,
+        batch.top_k_logits,
+        axis=-1,
+        inplace=False,
+    )
+
+    # Inject it as teacher_output so the trainer skips the teacher forward pass.
+    inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput(
+        logits=dense_logits, out_projection_activations=None
+    )
+
+    return inputs_dict
+
+  trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)

   # 9. Create Iterator Wrappers (Use Utils)
   train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)
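Note: for intuition, the scatter in `custom_gen_model_input_fn` above can be exercised standalone. A toy sketch with made-up shapes (the vocab size, k, and values here are illustrative only):

    import jax
    import jax.numpy as jnp

    batch, seq, vocab, k = 2, 4, 16, 3  # toy sizes

    # Stand-ins for arrays loaded from the offline file.
    top_k_indices = jnp.broadcast_to(jnp.arange(k, dtype=jnp.int32), (batch, seq, k))
    top_k_logits = jnp.ones((batch, seq, k), dtype=jnp.float32)

    # Fill with a large negative value, then scatter the k stored logits in.
    dense = jnp.full((batch, seq, vocab), -10000.0, dtype=jnp.float32)
    dense = jnp.put_along_axis(dense, top_k_indices, top_k_logits, axis=-1, inplace=False)

    # After softmax, essentially all probability mass sits on the k stored entries.
    probs = jax.nn.softmax(dense, axis=-1)

Storing only k logit/index pairs per position keeps the offline file small, while the -10000.0 fill keeps the reconstructed distribution numerically close to the teacher's truncated softmax.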
@@ -584,14 +620,11 @@
   max_logging.log("Distillation Complete.")


-def main(argv: Sequence[str]) -> None:
+def main(argv: Sequence[str], local_args) -> None:
   """Entry point for the script.

   Parses configuration, isolates Student and Teacher overrides, and triggers the
   training loop.
-
-  Args:
-    argv: List of command-line arguments. Expects [script_name, config_file, ...].
   """
   # 1. Parse Global Config to extract Overrides
   global_config = pyconfig.initialize(argv)
@@ -602,11 +635,11 @@
   student_config = pyconfig.initialize(argv, **student_overrides)

   # 3. Initialize TEACHER Config
-  # We isolate the Teacher from Student CLI arguments (like pruning params).
+  # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides

   # Ensure load_parameters_path is set in overrides
-  if not teacher_overrides.get("load_parameters_path"):
+  if not local_args.offline_distillation and not teacher_overrides.get("load_parameters_path"):
     raise ValueError(
         "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
         "in your config or arguments."
@@ -618,8 +651,26 @@ def main(argv: Sequence[str]) -> None:
   teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)

   # 4. Run Training
-  train_distill(student_config, teacher_config)
+  train_distill(student_config, teacher_config, local_args.offline_distillation, local_args.offline_data_dir)


 if __name__ == "__main__":
-  app.run(main)
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--offline_distillation",
+      action="store_true",
+      help="Pass this flag to enable offline distillation.",
+  )
+  parser.add_argument(
+      "--offline_data_dir",
+      type=str,
+      required=False,
+      default=None,
+      help="GCS or local path to the pre-generated ArrayRecord teacher data.",
+  )
+
+  # parse_known_args separates our custom flags from MaxText's standard args.
+  local_args, remaining_args = parser.parse_known_args()
+
+  main_wrapper = functools.partial(main, local_args=local_args)
+  app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
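
Note: because `parse_known_args` strips only the two new flags and forwards everything else to `app.run`, the script keeps MaxText's usual positional config file and key=value overrides. A hypothetical invocation (module path, config file, run name, and bucket are placeholders, not from this commit):

    python -m maxtext.trainers.post_train.distillation.train_distill \
        src/maxtext/configs/base.yml \
        run_name=offline_distill_run \
        --offline_distillation \
        --offline_data_dir=gs://my-bucket/distill/teacher_top_k_global.array_record

Omitting --offline_distillation keeps the original online path, in which case the teacher checkpoint must be supplied via teacher_overrides.load_parameters_path.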
