 a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
 """
 
-import os
 from typing import Any, Iterator, Sequence, Dict, Tuple
 
 from absl import app
 from MaxText import pyconfig
 from MaxText import tokenizer
 from MaxText import train_utils
-from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 
 # Tunix Imports
 from tunix.distillation import distillation_trainer
@@ -123,6 +121,34 @@ def optimizer_factory(learning_rate):
   return optimizer
 
 
+def create_forward_fn(config: pyconfig.HyperParameters):
+  """Creates a forward function closure that binds the specific model configuration.
+
+  Args:
+    config: The HyperParameters object for the specific model being wrapped.
+
+  Returns:
+    A callable `model_forward_fn` that matches the signature expected by the
+    Tunix `LogitStrategy` and handles the MaxText-specific forward call.
+  """
+
+  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
+    """Forward pass wrapper adapted for raw MaxText models."""
+    del kwargs  # Unused
+    del attention_mask  # Unused
+    del cache  # Unused
+
+    logits = model(
+        decoder_input_tokens=input_tokens,
+        decoder_positions=positions,
+        decoder_segment_ids=decoder_segment_ids,
+        enable_dropout=config.enable_dropout,
+    )
+    return logits
+
+  return model_forward_fn
+
+
 # -----------------------------------------------------------------------------
 # Custom Data Structures & Strategies
 # -----------------------------------------------------------------------------
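Since the same wrapper must serve both the Student and the Teacher, the factory binds each model's own config into its closure instead of relying on module-level state. A self-contained sketch of that pattern, using stand-in classes rather than real MaxText or Tunix types:

import dataclasses

@dataclasses.dataclass
class StubConfig:  # stand-in for pyconfig.HyperParameters
  enable_dropout: bool

def make_forward_fn(config: StubConfig):
  def forward_fn(model, input_tokens, positions, **unused_kwargs):
    # Each closure carries its own config, so e.g. a Student with dropout enabled
    # and a Teacher with dropout disabled can share this wrapper code unchanged.
    return model(input_tokens, positions, enable_dropout=config.enable_dropout)
  return forward_fn

student_fwd = make_forward_fn(StubConfig(enable_dropout=True))
teacher_fwd = make_forward_fn(StubConfig(enable_dropout=False))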
@@ -337,21 +363,18 @@ def __next__(self) -> MaxTextTrainingInput:
 # Model Loading
 # -----------------------------------------------------------------------------
 def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) -> nnx.Module:
-  """Loads a MaxText model and wraps it in a Tunix adapter.
+  """Loads a MaxText model.
 
   Args:
     config: The configuration object for this specific model (Student or Teacher).
     mesh: The global device mesh for sharding weights.
 
   Returns:
-    A TunixMaxTextAdapter instance wrapping the loaded MaxText model.
+    The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
   model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
-
-  with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True)
-  return tunix_model
+  return model
 
 
@@ -408,22 +431,9 @@ def labels_fn(targets, **kwargs):
     mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
     return one_hot * mask
 
-  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
-    """Forward pass wrapper for the MaxText models (Student and Teacher)."""
-    del kwargs  # Unused
-    # Tunix adapter ensures __call__ signature matches this
-    outputs = model(
-        input_tokens=input_tokens,
-        positions=positions,
-        cache=cache,
-        attention_mask=attention_mask,
-        decoder_segment_ids=decoder_segment_ids,  # Support sequence packing
-    )
-    return outputs[0]  # Return logits only
-
-  # Both Student and Teacher use the same forward logic via the adapter
-  student_forward_fn = model_forward_fn
-  teacher_forward_fn = model_forward_fn
+  # Both Student and Teacher share the same forward logic, each bound to its own config
+  student_forward_fn = create_forward_fn(student_config)
+  teacher_forward_fn = create_forward_fn(teacher_config)
 
   # Use Monitored strategy to enable KL/Soft/Hard Loss logging
   strategy = MonitoredLogitStrategy(
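For reference, the "soft" component that a KL-based logit strategy typically logs is the temperature-scaled KL divergence between teacher and student token distributions. A rough, self-contained JAX sketch of that idea follows; it illustrates the general technique, not Tunix's exact implementation, and the temperature value is only an example:

import jax
import jax.numpy as jnp

def soft_distillation_loss(student_logits, teacher_logits, temperature=2.0):
  """Temperature-scaled KL(teacher || student), averaged over all tokens."""
  t_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
  s_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
  t_log_probs = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
  kl = jnp.sum(t_probs * (t_log_probs - s_log_probs), axis=-1)
  # The temperature**2 factor keeps gradient magnitudes comparable to the hard loss.
  return jnp.mean(kl) * temperature**2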
@@ -438,7 +448,10 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   optimizer = get_distillation_optimizer(student_config, student_config.steps)
 
   checkpointing_options = checkpoint.CheckpointManagerOptions(
-      save_interval_steps=student_config.checkpoint_period, max_to_keep=student_config.max_num_checkpoints_to_keep
+      save_interval_steps=student_config.checkpoint_period,
+      max_to_keep=student_config.max_num_checkpoints_to_keep,
+      enable_async_checkpointing=student_config.async_checkpointing,
+      create=True,
   )
 
   profiler_options = None
@@ -477,7 +490,7 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   trainer._has_aux = True  # pylint: disable=protected-access
 
   # 6. Configure Input Mapping
-  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by the models
+  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by model_forward_fn
   trainer = trainer.with_gen_model_input_fn(
       lambda batch: {
           "input_tokens": batch.input_tokens,
@@ -560,13 +573,12 @@ def main(argv: Sequence[str]) -> None:
   # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides
 
-  # Ensure load_parameters_path is set (check overrides, then env var)
+  # Ensure load_parameters_path is set in overrides
   if not teacher_overrides.get("load_parameters_path"):
-    ckpt_path = os.environ.get("TEACHER_CHECKPOINT_PATH")
-    if ckpt_path:
-      teacher_overrides["load_parameters_path"] = ckpt_path
-    else:
-      max_logging.log("Warning: No load_parameters_path found for Teacher.")
+    raise ValueError(
+        "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
+        "in your config or arguments."
+    )
 
   # Construct sanitized argv: [script_name, config_file]
   # This ensures flags like `num_query_heads=16` passed in CLI don't affect the Teacher.
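With the environment-variable fallback removed, the checkpoint location has to arrive through the overrides themselves. A hypothetical example of what the script expects to find when it calls `teacher_overrides.get("load_parameters_path")` (the model name and bucket path are placeholders, not values from this change):

# Hypothetical teacher_overrides mapping as read from the config / CLI.
teacher_overrides = {
    "model_name": "llama3.1-70b",  # placeholder teacher model
    "load_parameters_path": "gs://my-bucket/teacher/checkpoints/0/items",  # placeholder path
}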