Skip to content

Commit 422a3be

Browse files
committed
run teacher/student in same jit function for performance.
1 parent 27b3790 commit 422a3be

4 files changed

Lines changed: 292 additions & 48 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from maxtext.utils import max_logging
3131
# Reuse MaxText's native checkpointing logic
3232
from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
33-
from tunix.distillation import distillation_trainer
34-
from tunix.distillation.strategies import logit
33+
from tunix.sft import peft_trainer
3534
from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
3635

3736

@@ -51,7 +50,7 @@ class DistillationForwardOutput:
5150

5251

5352
@flax.struct.dataclass(frozen=True)
54-
class MaxTextTrainingInput(distillation_trainer.TrainingInput):
53+
class MaxTextTrainingInput(peft_trainer.TrainingInput):
5554
"""Extended TrainingInput dataclass to carry MaxText-specific fields."""
5655

5756
#: Position indices for the tokens (for RoPE).
@@ -119,7 +118,6 @@ def __next__(self) -> MaxTextTrainingInput:
119118
return MaxTextTrainingInput(
120119
input_tokens=batch["inputs"],
121120
input_mask=input_mask,
122-
teacher_output=None,
123121
positions=batch["inputs_position"],
124122
decoder_segment_ids=seg_ids,
125123
targets=batch["targets"],
@@ -131,8 +129,8 @@ def __next__(self) -> MaxTextTrainingInput:
131129
# -----------------------------------------------------------------------------
132130
# Distillation Strategy
133131
# -----------------------------------------------------------------------------
134-
class CombinedDistillationStrategy(logit.LogitStrategy):
135-
"""Logit Strategy that returns detailed metrics for TensorBoard."""
132+
class CombinedDistillationStrategy:
133+
"""Strategy that returns detailed metrics for TensorBoard."""
136134

137135
def __init__(
138136
self,
@@ -150,25 +148,23 @@ def __init__(
150148
"""Initializes the Combined strategy.
151149
152150
Args:
153-
student_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute student model outputs.
154-
teacher_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute teacher model outputs.
155-
labels_fn: Inherited from `logit.LogitStrategy`. Function to compute labels from model inputs.
156-
temperature: Inherited from `logit.LogitStrategy`. Temperature for softening probabilities (> 0).
157-
alpha: Inherited from `logit.LogitStrategy`. Weight to balance distillation loss and task loss (0.0 to 1.0).
151+
student_forward_fn: Function to compute student model outputs.
152+
teacher_forward_fn: Function to compute teacher model outputs.
153+
labels_fn: Function to compute labels from model inputs.
154+
temperature: Temperature for softening probabilities (> 0).
155+
alpha: Weight to balance distillation loss and task loss (0.0 to 1.0).
158156
beta_feature: Weight to balance feature loss (0.0 to 1.0). 0.0 disables feature loss.
159157
layer_indices: Layer indices to apply feature loss.
160158
feature_loss_fn: A function that takes two jax.Arrays (student_map,
161159
teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
162160
cosine_distance_axis: The axis to use for cosine distance computation if
163161
feature_loss_fn is not provided. Defaults to -1.
164162
"""
165-
super().__init__(
166-
student_forward_fn=student_forward_fn,
167-
teacher_forward_fn=teacher_forward_fn,
168-
labels_fn=labels_fn,
169-
temperature=temperature,
170-
alpha=alpha,
171-
)
163+
self.student_forward_fn = student_forward_fn
164+
self.teacher_forward_fn = teacher_forward_fn
165+
self.labels_fn = labels_fn
166+
self.temperature = temperature
167+
self.alpha = alpha
172168
self.beta_feature = beta_feature
173169
self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None
174170

@@ -325,9 +321,9 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
325321

326322
# Standard Tunix Logic for Model/Optimizer
327323
if save_only_lora_params:
328-
params = nnx.state(model, nnx.LoRAParam)
324+
params = nnx.state(model.student_model, nnx.LoRAParam)
329325
else:
330-
params = nnx.state(model)
326+
params = nnx.state(model.student_model)
331327

332328
# Define standard SaveArgs once to reuse
333329
default_save_args = checkpoint.SaveArgs()

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from maxtext.utils import model_creation_utils
5555

5656
# Tunix Imports
57-
from tunix.distillation import distillation_trainer
57+
from tunix.sft import peft_trainer
5858
from tunix.sft import metrics_logger
5959
from tunix.sft import profiler
6060

@@ -174,13 +174,99 @@ def _log_config_details(config: pyconfig.HyperParameters, label: str) -> None:
174174
max_logging.log(f" Checkpoint: {config.load_parameters_path}")
175175

176176

177-
class MaxTextDistillationTrainer(distillation_trainer.DistillationTrainer):
177+
class ModelBundle(nnx.Module):
  """Container module that holds a teacher model and a student model together.

  The bundle itself is not callable; callers pick the model to run through
  `call_student` or `call_teacher`.
  """

  def __init__(self, teacher_model: nnx.Module, student_model: nnx.Module):
    """Stores the two wrapped models as attributes.

    Args:
      teacher_model: Model invoked by `call_teacher`.
      student_model: Model invoked by `call_student`.
    """
    self.teacher_model = teacher_model
    self.student_model = student_model

  def __call__(self, *args, **kwargs):
    """Always raises, since the bundle has no single forward pass of its own."""
    raise NotImplementedError("Use `call_student` or `call_teacher` explicitly.")

  def call_student(self, *args, **kwargs):
    """Forwards every argument to the student model and returns its output."""
    student_output = self.student_model(*args, **kwargs)
    return student_output

  def call_teacher(self, *args, **kwargs):
    """Runs the teacher model and passes its output through `stop_gradient`."""
    teacher_output = self.teacher_model(*args, **kwargs)
    return jax.lax.stop_gradient(teacher_output)
192+
193+
194+
class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
178195
"""Custom Trainer to preserve MaxText fields and log Teacher metrics.
179196
180197
This class overrides `_prepare_inputs` to ensure MaxText-specific fields
181198
(positions, segment_ids) are passed to the model.
182199
"""
183200

201+
def __init__(self, model, strategy, optimizer, training_config, **kwargs):
  """Initializes the trainer and scopes the optimizer to the student model.

  Args:
    model: Bundle object exposing a `student_model` attribute.
    strategy: Distillation strategy, stored on `self.strategy`.
    optimizer: Optimizer passed to the parent constructor and then wrapped in
      an `nnx.Optimizer` built over `model.student_model`.
    training_config: Training configuration forwarded to the parent class.
    **kwargs: Extra keyword arguments forwarded to the parent constructor.
  """
  super().__init__(
      model=model,
      optimizer=optimizer,
      training_config=training_config,
      **kwargs,
  )
  self.strategy = strategy

  # Override the optimizer so that it only covers the student model.
  if self._lora_enabled:
    trainable_filter = nnx.LoRAParam
  else:
    trainable_filter = nnx.Param
  self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=trainable_filter)
209+
210+
def _train_step(self, model, optimizer, inputs):
211+
"""Overrides the main JIT block to natively handle ModelBundle module."""
212+
213+
batch = self.gen_model_input_fn(inputs)
214+
215+
def loss_wrapper(student, teacher, batch):
216+
if "teacher_output" in batch:
217+
teacher_output = batch["teacher_output"]
218+
else:
219+
teacher_output = self.strategy.teacher_forward_fn(
220+
model=teacher,
221+
input_tokens=batch["input_tokens"],
222+
positions=batch["positions"],
223+
attention_mask=batch.get("attention_mask"),
224+
decoder_segment_ids=batch.get("decoder_segment_ids"),
225+
cache=None,
226+
)
227+
228+
teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
229+
230+
student_output = self.strategy.student_forward_fn(
231+
model=student,
232+
input_tokens=batch["input_tokens"],
233+
positions=batch["positions"],
234+
attention_mask=batch.get("attention_mask"),
235+
decoder_segment_ids=batch.get("decoder_segment_ids"),
236+
cache=None,
237+
)
238+
labels = self.strategy.labels_fn(batch["targets"])
239+
return self.strategy.compute_loss(student_output, teacher_output, labels)
240+
241+
# Because student is the 0th argument, argnums=0 guarantees
242+
# we only compute gradients for the student.
243+
grad_fn = nnx.value_and_grad(
244+
loss_wrapper,
245+
argnums=0,
246+
has_aux=True,
247+
)
248+
249+
out, grads = grad_fn(model.student_model, model.teacher_model, batch)
250+
251+
optimizer.update(model.student_model, grads)
252+
253+
return out[0], out[1] # loss, aux
254+
255+
def _eval_step(self, model, inputs):
256+
"""Evaluation only needs the student."""
257+
inputs = self.gen_model_input_fn(inputs)
258+
259+
student_output = self.strategy.student_forward_fn(
260+
model=model.student_model,
261+
input_tokens=inputs["input_tokens"],
262+
positions=inputs["positions"],
263+
attention_mask=inputs.get("attention_mask"),
264+
decoder_segment_ids=inputs.get("decoder_segment_ids"),
265+
cache=None,
266+
)
267+
labels = self.strategy.labels_fn(inputs["targets"])
268+
return self.strategy.compute_eval_loss(student_output, labels)
269+
184270
def _prepare_inputs(
185271
self, input_data: distillation_utils.MaxTextTrainingInput
186272
) -> distillation_utils.MaxTextTrainingInput:
@@ -195,22 +281,12 @@ def _prepare_inputs(
195281
Returns:
196282
A new MaxTextTrainingInput carrying the MaxText-specific fields.
197283
"""
198-
# 1. Generate inputs dictionary for the Teacher model
199-
inputs = self.gen_model_input_fn(input_data)["inputs"]
200-
201-
if self._mode == metrics_logger.Mode.EVAL:
202-
teacher_output = None
203-
else:
204-
# 2. Run Teacher to get soft targets (logits)
205-
# The strategy ensures these are stop_gradient-ed
206-
teacher_output = self.strategy.get_teacher_outputs(self.teacher_model, inputs)
207284

208285
# Return extended object so fields are available for Student training step
209286
# pylint: disable=unexpected-keyword-arg
210287
return distillation_utils.MaxTextTrainingInput(
211288
input_tokens=input_data.input_tokens,
212289
input_mask=input_data.input_mask,
213-
teacher_output=teacher_output,
214290
positions=input_data.positions,
215291
decoder_segment_ids=input_data.decoder_segment_ids,
216292
targets=input_data.targets,
@@ -380,8 +456,6 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
380456
sft_mode=student_config.use_sft,
381457
)
382458

383-
student_model, teacher_model = strategy.pre_process_models(student_model, teacher_model)
384-
385459
# 4. Optimizer & Config
386460
optimizer = get_distillation_optimizer(student_config, student_config.steps)
387461

@@ -405,7 +479,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
405479
log_dir=student_config.tensorboard_dir, flush_every_n_steps=student_config.log_period
406480
)
407481

408-
train_config = distillation_trainer.TrainingConfig(
482+
train_config = peft_trainer.TrainingConfig(
409483
max_steps=student_config.steps,
410484
eval_every_n_steps=student_config.eval_interval,
411485
metrics_logging_options=metrics_logging_options,
@@ -419,10 +493,14 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
419493
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
420494
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
421495

496+
teacher_model.eval()
497+
student_model.train()
498+
499+
model_bundle = ModelBundle(teacher_model, student_model)
500+
422501
# 6. Initialize Trainer
423502
trainer = MaxTextDistillationTrainer(
424-
student_model=student_model,
425-
teacher_model=teacher_model,
503+
model=model_bundle,
426504
strategy=strategy,
427505
optimizer=optimizer,
428506
training_config=train_config,
@@ -472,7 +550,10 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
472550
max_logging.log(f"Saving final checkpoint to {student_config.checkpoint_dir}...")
473551
try:
474552
saved = trainer.checkpoint_manager.save(
475-
trainer.train_steps, trainer.model, save_only_lora_params=getattr(trainer, "_lora_enabled", False), force=True
553+
trainer.train_steps,
554+
trainer.model.student_model,
555+
save_only_lora_params=getattr(trainer, "_lora_enabled", False),
556+
force=True,
476557
)
477558
if saved:
478559
# Ensure underlying orbax manager finishes writing

tests/unit/distillation_checkpointing_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def test_save_and_restore_iterator(self):
9191
)
9292

9393
# Create dummy model so 'model_params' is not empty
94-
model = DummyModel(nnx.Rngs(0))
94+
model = mock.Mock()
95+
model.student_model = DummyModel(nnx.Rngs(0))
9596

9697
# Mock jax.process_index/count to simulate single host
9798
with mock.patch.object(jax, "process_index", return_value=0), mock.patch.object(jax, "process_count", return_value=1):

0 commit comments

Comments
 (0)