Skip to content

Commit b723c4e

Browse files
Merge pull request #3382 from AI-Hypercomputer:ajkv/offline-distillation-soft
PiperOrigin-RevId: 886869786
2 parents eb9d12f + 5fb1c6c commit b723c4e

4 files changed

Lines changed: 276 additions & 27 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,11 @@ class Distillation(BaseModel):
10901090
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
10911091
)
10921092

1093+
# --- Offline Distillation Field ---
1094+
offline_data_dir: Optional[str] = Field(
1095+
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
1096+
)
1097+
10931098
# --- Loss Params ---
10941099
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
10951100
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
model structures with Tunix's training interfaces.
1919
"""
2020

21+
import pickle
22+
import tensorflow as tf
23+
from array_record.python import array_record_module
24+
2125
from typing import Any, Iterator, Optional, List, Callable
2226

2327
import flax
@@ -63,13 +67,60 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput):
6367
targets_position: jax.Array = None
6468
#: Segment IDs for packed target tokens.
6569
targets_segmentation: jax.Array = None
70+
#: Top-K logits from the teacher model.
71+
top_k_logits: jax.Array = None
72+
top_k_indices: jax.Array = None
6673

6774

6875
# -----------------------------------------------------------------------------
6976
# Data Loading Adapter
7077
# -----------------------------------------------------------------------------
7178

7279

80+
class OfflineArrayRecordIterator:
  """Reads pre-generated global top-k teacher logits from an ArrayRecord file.

  Each record is a pickled dict containing at least ``tokens``,
  ``top_k_logits`` and ``top_k_indices``; the iterator yields MaxText-style
  batch dictionaries and loops over the file for a fixed number of epochs
  before raising ``StopIteration``.
  """

  def __init__(self, data_dir: str, epochs: int = 100):
    """Opens the offline distillation file.

    Args:
      data_dir: GCS or local path to the ArrayRecord file of pickled batches.
      epochs: Number of full passes over the file before exhaustion.

    Raises:
      FileNotFoundError: If the file does not exist.
      ValueError: If the file contains no records.
    """
    self.filepath = data_dir

    if not tf.io.gfile.exists(self.filepath):
      raise FileNotFoundError(f"Offline distillation file not found: {self.filepath}")

    self.reader = array_record_module.ArrayRecordReader(self.filepath)
    self.num_records = self.reader.num_records()
    # Guard: without this check the epoch-rollover logic in __next__ would
    # spin through every epoch and then attempt a read on an empty reader.
    if self.num_records == 0:
      raise ValueError(f"Offline distillation file is empty: {self.filepath}")
    self.epochs = epochs
    self.current_epoch = 0
    self.record_index = 0

  def __iter__(self):
    return self

  def __next__(self):
    # Epoch rollover: re-open the reader so reads restart at record 0.
    if self.record_index >= self.num_records:
      self.current_epoch += 1
      if self.current_epoch >= self.epochs:
        raise StopIteration

      self.record_index = 0
      self.reader = array_record_module.ArrayRecordReader(self.filepath)

    # NOTE(review): assumes the no-arg read() returns the *next* record
    # sequentially — confirm against the array_record Python API.
    record = self.reader.read()
    self.record_index += 1
    # SECURITY: pickle.loads executes arbitrary code during deserialization.
    # Only point this iterator at trusted, self-generated data files.
    data = pickle.loads(record)

    # Map the arrays to match MaxText's expected dictionary.
    batch = {
        "inputs": data["tokens"],
        "top_k_logits": data["top_k_logits"],
        "top_k_indices": data["top_k_indices"],
    }
    # Optional pipeline fields are forwarded only when present in the record.
    for key in ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]:
      if key in data:
        batch[key] = data[key]

    return batch
73124
class MaxTextToTunixIterator:
74125
"""Adapts the raw dictionary output of MaxText's data loader to Tunix objects.
75126
@@ -123,6 +174,8 @@ def __next__(self) -> MaxTextTrainingInput:
123174
targets=batch["targets"],
124175
targets_position=targets_position,
125176
targets_segmentation=targets_segmentation,
177+
top_k_logits=batch.get("top_k_logits"),
178+
top_k_indices=batch.get("top_k_indices"),
126179
)
127180

128181

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
3333
a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
3434
"""
35-
3635
from typing import Sequence, Callable
3736
from absl import app
3837
from flax import nnx
@@ -303,6 +302,8 @@ def _prepare_inputs(
303302
targets=input_data.targets,
304303
targets_position=input_data.targets_position,
305304
targets_segmentation=input_data.targets_segmentation,
305+
top_k_logits=input_data.top_k_logits,
306+
top_k_indices=input_data.top_k_indices,
306307
)
307308

308309
def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -401,7 +402,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
401402
# -----------------------------------------------------------------------------
402403

403404

404-
def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None:
405+
def train_distill(
406+
student_config: pyconfig.HyperParameters,
407+
teacher_config: pyconfig.HyperParameters,
408+
is_offline: bool = False,
409+
offline_data_dir: str | None = None,
410+
) -> None:
405411
"""Main distillation training loop.
406412
407413
Orchestrates the loading of both student and teacher models, configures the
@@ -437,9 +443,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
437443
_log_config_details(student_config, "Student")
438444
student_model = get_maxtext_model(student_config, mesh)
439445

440-
max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
441-
_log_config_details(teacher_config, "Teacher")
442-
teacher_model = get_maxtext_model(teacher_config, mesh)
446+
# Skip teacher model loading if offline
447+
if is_offline:
448+
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
449+
teacher_model = None
450+
else:
451+
max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
452+
_log_config_details(teacher_config, "Teacher")
453+
teacher_model = get_maxtext_model(teacher_config, mesh)
454+
teacher_model.eval()
443455

444456
# 3. Define Distillation Strategy
445457
def labels_fn(targets, targets_segmentation=None, **kwargs):
@@ -502,13 +514,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
502514
)
503515

504516
# 5. Data Iterators (Init BEFORE Trainer)
505-
# We use MaxText's native create_data_iterator which creates both train and eval iterators
506-
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
507-
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
517+
if is_offline:
518+
max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...")
519+
raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir)
520+
raw_eval_iter = None
521+
else:
522+
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
523+
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
508524

509-
teacher_model.eval()
510525
student_model.train()
511-
512526
model_bundle = ModelBundle(teacher_model, student_model)
513527

514528
# 6. Initialize Trainer
@@ -526,18 +540,35 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
526540
raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)
527541

528542
# 8. Configure Input Mapping
529-
trainer = trainer.with_gen_model_input_fn(
530-
lambda batch: {
531-
"input_tokens": batch.input_tokens,
532-
"positions": batch.positions,
533-
"attention_mask": batch.input_mask,
534-
"decoder_segment_ids": batch.decoder_segment_ids,
535-
"targets": batch.targets, # Passed to strategy (labels_fn)
536-
"targets_position": batch.targets_position, # Passed to strategy (labels_fn)
537-
"targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn)
538-
"cache": None,
539-
}
540-
)
543+
def custom_gen_model_input_fn(batch):
  """Maps a training batch to the keyword arguments the model call expects.

  In offline mode (batch carries pre-computed sparse top-k teacher logits),
  the sparse arrays are scattered into a dense [batch, seq, vocab] tensor and
  injected as ``teacher_output`` so the trainer skips the teacher forward pass.
  """
  model_inputs = {
      "input_tokens": batch.input_tokens,
      "positions": batch.positions,
      "attention_mask": batch.input_mask,
      "decoder_segment_ids": batch.decoder_segment_ids,
      "targets": batch.targets,
      "targets_position": batch.targets_position,
      "targets_segmentation": batch.targets_segmentation,
      "cache": None,
  }

  # Online mode: no pre-computed teacher logits, nothing more to add.
  if getattr(batch, "top_k_logits", None) is None:
    return model_inputs

  # Offline mode: scatter the top-k values into a dense tensor whose other
  # entries hold a large negative logit (-10000.0), i.e. ~zero probability
  # after the softmax inside the distillation loss.
  full_shape = batch.input_tokens.shape + (student_config.vocab_size,)
  teacher_logits = jnp.full(full_shape, -10000.0, dtype=jnp.float32)
  teacher_logits = jnp.put_along_axis(
      teacher_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False
  )

  # Presence of teacher_output tells the trainer to skip the teacher pass.
  model_inputs["teacher_output"] = distillation_utils.DistillationForwardOutput(
      logits=teacher_logits, out_projection_activations=None
  )

  return model_inputs
570+
571+
trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
541572

542573
# 9. Create Iterator Wrappers (Use Utils)
543574
train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)
@@ -589,9 +620,6 @@ def main(argv: Sequence[str]) -> None:
589620
590621
Parses configuration, isolates Student and Teacher overrides, and triggers the
591622
training loop.
592-
593-
Args:
594-
argv: List of command-line arguments. Expects [script_name, config_file, ...].
595623
"""
596624
# 1. Parse Global Config to extract Overrides
597625
global_config = pyconfig.initialize(argv)
@@ -601,12 +629,14 @@ def main(argv: Sequence[str]) -> None:
601629
student_overrides = global_config.student_overrides
602630
student_config = pyconfig.initialize(argv, **student_overrides)
603631

632+
is_offline = bool(global_config.offline_data_dir)
633+
604634
# 3. Initialize TEACHER Config
605635
# We isolate the Teacher from Student CLI arguments (like pruning params).
606636
teacher_overrides = global_config.teacher_overrides
607637

608638
# Ensure load_parameters_path is set in overrides
609-
if not teacher_overrides.get("load_parameters_path"):
639+
if not is_offline and not teacher_overrides.get("load_parameters_path"):
610640
raise ValueError(
611641
"Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
612642
"in your config or arguments."
@@ -618,7 +648,7 @@ def main(argv: Sequence[str]) -> None:
618648
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
619649

620650
# 4. Run Training
621-
train_distill(student_config, teacher_config)
651+
train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
622652

623653

624654
if __name__ == "__main__":

0 commit comments

Comments
 (0)