3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
   a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
"""
35-
3635from typing import Sequence , Callable
3736from absl import app
3837from flax import nnx
@@ -303,6 +302,8 @@ def _prepare_inputs(
303302 targets = input_data .targets ,
304303 targets_position = input_data .targets_position ,
305304 targets_segmentation = input_data .targets_segmentation ,
305+ top_k_logits = input_data .top_k_logits ,
306+ top_k_indices = input_data .top_k_indices ,
306307 )
307308
308309 def _post_process_train_step (self , aux : dict [str , jax .Array ]) -> None :
@@ -401,7 +402,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
401402# -----------------------------------------------------------------------------
402403
403404
404- def train_distill (student_config : pyconfig .HyperParameters , teacher_config : pyconfig .HyperParameters ) -> None :
405+ def train_distill (
406+ student_config : pyconfig .HyperParameters ,
407+ teacher_config : pyconfig .HyperParameters ,
408+ is_offline : bool = False ,
409+ offline_data_dir : str | None = None ,
410+ ) -> None :
405411 """Main distillation training loop.
406412
407413 Orchestrates the loading of both student and teacher models, configures the
@@ -437,9 +443,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
437443 _log_config_details (student_config , "Student" )
438444 student_model = get_maxtext_model (student_config , mesh )
439445
440- max_logging .log (f"Loading Teacher from { teacher_config .load_parameters_path } ..." )
441- _log_config_details (teacher_config , "Teacher" )
442- teacher_model = get_maxtext_model (teacher_config , mesh )
446+ # Skip teacher model loading if offline
447+ if is_offline :
448+ max_logging .log ("Offline Distillation: Skipping Teacher Model loading." )
449+ teacher_model = None
450+ else :
451+ max_logging .log (f"Loading Teacher from { teacher_config .load_parameters_path } ..." )
452+ _log_config_details (teacher_config , "Teacher" )
453+ teacher_model = get_maxtext_model (teacher_config , mesh )
454+ teacher_model .eval ()
443455
444456 # 3. Define Distillation Strategy
445457 def labels_fn (targets , targets_segmentation = None , ** kwargs ):
@@ -502,13 +514,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
502514 )
503515
504516 # 5. Data Iterators (Init BEFORE Trainer)
505- # We use MaxText's native create_data_iterator which creates both train and eval iterators
506- max_logging .log ("Initializing Data Iterators via MaxText pipeline..." )
507- raw_train_iter , raw_eval_iter = input_pipeline_interface .create_data_iterator (student_config , mesh )
517+ if is_offline :
518+ max_logging .log (f"Loading Offline Dataset from { offline_data_dir } ..." )
519+ raw_train_iter = distillation_utils .OfflineArrayRecordIterator (offline_data_dir )
520+ raw_eval_iter = None
521+ else :
522+ max_logging .log ("Initializing Data Iterators via MaxText pipeline..." )
523+ raw_train_iter , raw_eval_iter = input_pipeline_interface .create_data_iterator (student_config , mesh )
508524
509- teacher_model .eval ()
510525 student_model .train ()
511-
512526 model_bundle = ModelBundle (teacher_model , student_model )
513527
514528 # 6. Initialize Trainer
@@ -526,18 +540,35 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
526540 raw_train_iter = _setup_and_restore_input_pipeline (trainer , raw_train_iter , student_config , train_config )
527541
528542 # 8. Configure Input Mapping
529- trainer = trainer .with_gen_model_input_fn (
530- lambda batch : {
531- "input_tokens" : batch .input_tokens ,
532- "positions" : batch .positions ,
533- "attention_mask" : batch .input_mask ,
534- "decoder_segment_ids" : batch .decoder_segment_ids ,
535- "targets" : batch .targets , # Passed to strategy (labels_fn)
536- "targets_position" : batch .targets_position , # Passed to strategy (labels_fn)
537- "targets_segmentation" : batch .targets_segmentation , # Passed to strategy (labels_fn)
538- "cache" : None ,
539- }
540- )
def custom_gen_model_input_fn(batch):
  """Map a MaxText batch onto the kwargs expected by the Tunix trainer.

  For online distillation the mapping is a plain field-to-kwarg copy. For
  offline distillation (the batch carries precomputed ``top_k_logits`` /
  ``top_k_indices``), the sparse teacher logits are scattered back into a
  dense ``[..., vocab_size]`` tensor and attached as ``teacher_output`` so
  the trainer can skip the teacher forward pass.
  """
  model_inputs = {
      "input_tokens": batch.input_tokens,
      "positions": batch.positions,
      "attention_mask": batch.input_mask,
      "decoder_segment_ids": batch.decoder_segment_ids,
      "targets": batch.targets,
      "targets_position": batch.targets_position,
      "targets_segmentation": batch.targets_segmentation,
      "cache": None,
  }

  top_k_logits = getattr(batch, "top_k_logits", None)
  if top_k_logits is None:
    # Online mode: no precomputed teacher logits on the batch.
    return model_inputs

  # Offline mode: densify the stored top-k logits. Positions outside the
  # teacher's top-k keep a large negative fill (-10000.0) so they carry
  # negligible probability mass after softmax.
  dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
  dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
  dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, top_k_logits, axis=-1, inplace=False)

  # Presenting the densified logits as teacher_output lets the trainer
  # bypass the (absent) teacher model's forward pass.
  model_inputs["teacher_output"] = distillation_utils.DistillationForwardOutput(
      logits=dense_logits, out_projection_activations=None
  )
  return model_inputs
570+
571+ trainer = trainer .with_gen_model_input_fn (custom_gen_model_input_fn )
541572
542573 # 9. Create Iterator Wrappers (Use Utils)
543574 train_iter = distillation_utils .MaxTextToTunixIterator (raw_train_iter )
@@ -589,9 +620,6 @@ def main(argv: Sequence[str]) -> None:
589620
590621 Parses configuration, isolates Student and Teacher overrides, and triggers the
591622 training loop.
592-
593- Args:
594- argv: List of command-line arguments. Expects [script_name, config_file, ...].
595623 """
596624 # 1. Parse Global Config to extract Overrides
597625 global_config = pyconfig .initialize (argv )
@@ -601,12 +629,14 @@ def main(argv: Sequence[str]) -> None:
601629 student_overrides = global_config .student_overrides
602630 student_config = pyconfig .initialize (argv , ** student_overrides )
603631
632+ is_offline = bool (global_config .offline_data_dir )
633+
604634 # 3. Initialize TEACHER Config
605635 # We isolate the Teacher from Student CLI arguments (like pruning params).
606636 teacher_overrides = global_config .teacher_overrides
607637
608638 # Ensure load_parameters_path is set in overrides
609- if not teacher_overrides .get ("load_parameters_path" ):
639+ if not is_offline and not teacher_overrides .get ("load_parameters_path" ):
610640 raise ValueError (
611641 "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
612642 "in your config or arguments."
@@ -618,7 +648,7 @@ def main(argv: Sequence[str]) -> None:
618648 teacher_config = pyconfig .initialize (teacher_argv , ** teacher_overrides )
619649
620650 # 4. Run Training
621- train_distill (student_config , teacher_config )
651+ train_distill (student_config , teacher_config , is_offline , global_config . offline_data_dir )
622652
623653
624654if __name__ == "__main__" :
0 commit comments