
Commit 05bbde3

Merge pull request #2927 from AI-Hypercomputer:agagik-distill-2
PiperOrigin-RevId: 856367359
2 parents e171f7a + 6bd9e12 commit 05bbde3

1 file changed


src/MaxText/distillation/train_distill.py

Lines changed: 43 additions & 31 deletions
@@ -33,7 +33,6 @@
 a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
 """
 
-import os
 from typing import Any, Iterator, Sequence, Dict, Tuple
 
 from absl import app
@@ -54,7 +53,6 @@
 from MaxText import pyconfig
 from MaxText import tokenizer
 from MaxText import train_utils
-from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 
 # Tunix Imports
 from tunix.distillation import distillation_trainer
@@ -123,6 +121,34 @@ def optimizer_factory(learning_rate):
   return optimizer
 
 
+def create_forward_fn(config: pyconfig.HyperParameters):
+  """Creates a forward function closure that binds the specific model configuration.
+
+  Args:
+    config: The HyperParameters object for the specific model being wrapped.
+
+  Returns:
+    A callable `model_forward_fn` that matches the signature expected by the
+    Tunix `LogitStrategy` and handles the MaxText-specific forward call.
+  """
+
+  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
+    """Forward pass wrapper adapted for raw MaxText models."""
+    del kwargs  # Unused
+    del attention_mask  # Unused
+    del cache  # Unused
+
+    logits = model(
+        decoder_input_tokens=input_tokens,
+        decoder_positions=positions,
+        decoder_segment_ids=decoder_segment_ids,
+        enable_dropout=config.enable_dropout,
+    )
+    return logits
+
+  return model_forward_fn
+
+
 # -----------------------------------------------------------------------------
 # Custom Data Structures & Strategies
 # -----------------------------------------------------------------------------
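The new `create_forward_fn` is a closure factory: it captures the per-model `config` (notably `enable_dropout`) so that the returned wrapper keeps the fixed keyword signature Tunix's `LogitStrategy` calls, while still threading model-specific settings into the raw MaxText forward pass. A minimal usage sketch, assuming placeholder `student_config`, `student_model`, and batch objects rather than the real MaxText/Tunix types:

# Hypothetical sketch -- the names below are illustrative, not taken from the diff.
student_forward_fn = create_forward_fn(student_config)

# Tunix calls the wrapper with a fixed signature; the wrapper discards the
# arguments a raw MaxText model does not accept (attention_mask, cache) and
# forwards the rest together with the captured config.enable_dropout flag.
logits = student_forward_fn(
    student_model,
    input_tokens=batch.input_tokens,
    positions=batch.input_positions,
    attention_mask=batch.attention_mask,
    decoder_segment_ids=batch.decoder_segment_ids,
)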
@@ -337,21 +363,18 @@ def __next__(self) -> MaxTextTrainingInput:
 # Model Loading
 # -----------------------------------------------------------------------------
 def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) -> nnx.Module:
-  """Loads a MaxText model and wraps it in a Tunix adapter.
+  """Loads a MaxText model.
 
   Args:
     config: The configuration object for this specific model (Student or Teacher).
     mesh: The global device mesh for sharding weights.
 
   Returns:
-    A TunixMaxTextAdapter instance wrapping the loaded MaxText model.
+    The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
   model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
-
-  with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True)
-  return tunix_model
+  return model
 
 
 # -----------------------------------------------------------------------------
@@ -408,22 +431,9 @@ def labels_fn(targets, **kwargs):
     mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
     return one_hot * mask
 
-  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
-    """Forward pass wrapper for the MaxText models (Student and Teacher)."""
-    del kwargs  # Unused
-    # Tunix adapter ensures __call__ signature matches this
-    outputs = model(
-        input_tokens=input_tokens,
-        positions=positions,
-        cache=cache,
-        attention_mask=attention_mask,
-        decoder_segment_ids=decoder_segment_ids,  # Support sequence packing
-    )
-    return outputs[0]  # Return logits only
-
   # Both Student and Teacher use the same forward logic via the adapter
-  student_forward_fn = model_forward_fn
-  teacher_forward_fn = model_forward_fn
+  student_forward_fn = create_forward_fn(student_config)
+  teacher_forward_fn = create_forward_fn(teacher_config)
 
   # Use Monitored strategy to enable KL/Soft/Hard Loss logging
   strategy = MonitoredLogitStrategy(
@@ -438,7 +448,10 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   optimizer = get_distillation_optimizer(student_config, student_config.steps)
 
   checkpointing_options = checkpoint.CheckpointManagerOptions(
-      save_interval_steps=student_config.checkpoint_period, max_to_keep=student_config.max_num_checkpoints_to_keep
+      save_interval_steps=student_config.checkpoint_period,
+      max_to_keep=student_config.max_num_checkpoints_to_keep,
+      enable_async_checkpointing=student_config.async_checkpointing,
+      create=True,
   )
 
   profiler_options = None
@@ -477,7 +490,7 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   trainer._has_aux = True  # pylint: disable=protected-access
 
   # 6. Configure Input Mapping
-  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by the models
+  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by model_forward_fn
   trainer = trainer.with_gen_model_input_fn(
       lambda batch: {
           "input_tokens": batch.input_tokens,
@@ -560,13 +573,12 @@ def main(argv: Sequence[str]) -> None:
   # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides
 
-  # Ensure load_parameters_path is set (check overrides, then env var)
+  # Ensure load_parameters_path is set in overrides
   if not teacher_overrides.get("load_parameters_path"):
-    ckpt_path = os.environ.get("TEACHER_CHECKPOINT_PATH")
-    if ckpt_path:
-      teacher_overrides["load_parameters_path"] = ckpt_path
-    else:
-      max_logging.log("Warning: No load_parameters_path found for Teacher.")
+    raise ValueError(
+        "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
+        "in your config or arguments."
+    )
 
   # Construct sanitized argv: [script_name, config_file]
   # This ensures flags like `num_query_heads=16` passed in CLI don't affect the Teacher.
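After this change the teacher checkpoint path must come from the configuration: the fallback to a TEACHER_CHECKPOINT_PATH environment variable is removed, and a missing path now fails fast instead of logging a warning. A minimal sketch of the shape the guard expects, with a placeholder path (how `teacher_overrides` is actually populated depends on your config file or CLI overrides):

# Hypothetical illustration -- the checkpoint path below is a placeholder.
teacher_overrides = {"load_parameters_path": "gs://my-bucket/teacher/checkpoints/0/items"}

# Satisfies the new check; an empty or missing entry would raise ValueError.
assert teacher_overrides.get("load_parameters_path")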
