 3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
    a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
 """
+import argparse
+import functools
+import sys
 
 from typing import Sequence, Callable
 from absl import app
@@ -303,6 +306,8 @@ def _prepare_inputs(
         targets=input_data.targets,
         targets_position=input_data.targets_position,
         targets_segmentation=input_data.targets_segmentation,
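+        # Offline distillation only: pre-computed teacher top-k logits and their vocab indices (absent/None online).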
+        top_k_logits=input_data.top_k_logits,
+        top_k_indices=input_data.top_k_indices,
     )
 
   def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -401,7 +406,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
 # -----------------------------------------------------------------------------
 
 
-def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None:
+def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters, is_offline: bool = False, offline_data_dir: str | None = None) -> None:
   """Main distillation training loop.
 
   Orchestrates the loading of both student and teacher models, configures the
@@ -437,9 +442,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
   _log_config_details(student_config, "Student")
   student_model = get_maxtext_model(student_config, mesh)
 
-  max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
-  _log_config_details(teacher_config, "Teacher")
-  teacher_model = get_maxtext_model(teacher_config, mesh)
+  # Offline mode: teacher logits come from the pre-generated dataset, so no teacher model is needed.
+  if is_offline:
+    max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
+    teacher_model = None
+  else:
+    max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
+    _log_config_details(teacher_config, "Teacher")
+    teacher_model = get_maxtext_model(teacher_config, mesh)
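+    # The teacher is inference-only; eval() switches the module to deterministic mode (e.g. dropout disabled).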
+    teacher_model.eval()
 
   # 3. Define Distillation Strategy
   def labels_fn(targets, targets_segmentation=None, **kwargs):
@@ -502,13 +513,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   )
 
   # 5. Data Iterators (Init BEFORE Trainer)
-  # We use MaxText's native create_data_iterator which creates both train and eval iterators
-  max_logging.log("Initializing Data Iterators via MaxText pipeline...")
-  raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
+  if is_offline:
+    max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...")
+    raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir)
+    raw_eval_iter = None
+  else:
+    max_logging.log("Initializing Data Iterators via MaxText pipeline...")
+    raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
 
-  teacher_model.eval()
   student_model.train()
-
   model_bundle = ModelBundle(teacher_model, student_model)
 
   # 6. Initialize Trainer
@@ -526,18 +539,41 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)
 
   # 8. Configure Input Mapping
-  trainer = trainer.with_gen_model_input_fn(
-      lambda batch: {
-          "input_tokens": batch.input_tokens,
-          "positions": batch.positions,
-          "attention_mask": batch.input_mask,
-          "decoder_segment_ids": batch.decoder_segment_ids,
-          "targets": batch.targets,  # Passed to strategy (labels_fn)
-          "targets_position": batch.targets_position,  # Passed to strategy (labels_fn)
-          "targets_segmentation": batch.targets_segmentation,  # Passed to strategy (labels_fn)
-          "cache": None,
-      }
-  )
+  def custom_gen_model_input_fn(batch):
+    inputs_dict = {
+        "input_tokens": batch.input_tokens,
+        "positions": batch.positions,
+        "attention_mask": batch.input_mask,
+        "decoder_segment_ids": batch.decoder_segment_ids,
+        "targets": batch.targets,
+        "targets_position": batch.targets_position,
+        "targets_segmentation": batch.targets_segmentation,
+        "cache": None,
+    }
+
+    # Online mode: the batch carries no pre-computed teacher logits, so return the inputs unchanged.
+    if getattr(batch, "top_k_logits", None) is None:
+      return inputs_dict
+
+    # Offline mode: scatter the stored top-k teacher logits into a dense [batch, seq, vocab]
+    # tensor, filling all other vocabulary entries with -10000.0 (~zero probability after softmax).
+    dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
+    dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
+    dense_logits = jnp.put_along_axis(
+        dense_logits,
+        batch.top_k_indices,
+        batch.top_k_logits,
+        axis=-1,
+        inplace=False,
+    )
+
+    # Inject the dense logits as teacher_output so the trainer skips the teacher forward pass.
+    inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput(
+        logits=dense_logits, out_projection_activations=None
+    )
+
+    return inputs_dict
+
+  trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
 
   # 9. Create Iterator Wrappers (Use Utils)
   train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)
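
For intuition, a minimal standalone sketch of the scatter performed in custom_gen_model_input_fn above (not part of the diff; toy shapes and values are illustrative):

    import jax.numpy as jnp

    # Toy sizes: batch=1, seq=2, vocab=6, k=2.
    top_k_indices = jnp.array([[[0, 3], [2, 5]]])         # [batch, seq, k] vocab ids
    top_k_logits = jnp.array([[[4.0, 1.5], [2.0, 0.5]]])  # [batch, seq, k] teacher scores

    # Every non-top-k vocab slot gets -10000.0, i.e. ~zero probability after softmax.
    dense = jnp.full((1, 2, 6), -10000.0, dtype=jnp.float32)
    dense = jnp.put_along_axis(dense, top_k_indices, top_k_logits, axis=-1, inplace=False)

    print(dense[0, 0])  # [4.0, -10000.0, -10000.0, 1.5, -10000.0, -10000.0]
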
@@ -584,14 +620,11 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   max_logging.log("Distillation Complete.")
 
 
-def main(argv: Sequence[str]) -> None:
+def main(argv: Sequence[str], local_args: argparse.Namespace) -> None:
   """Entry point for the script.
 
   Parses configuration, isolates Student and Teacher overrides, and triggers the
   training loop.
-
-  Args:
-    argv: List of command-line arguments. Expects [script_name, config_file, ...].
   """
   # 1. Parse Global Config to extract Overrides
   global_config = pyconfig.initialize(argv)
@@ -602,11 +635,11 @@ def main(argv: Sequence[str]) -> None:
   student_config = pyconfig.initialize(argv, **student_overrides)
 
   # 3. Initialize TEACHER Config
-  # We isolate the Teacher from Student CLI arguments (like pruning params).
+  # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides
 
   # Ensure load_parameters_path is set in overrides
-  if not teacher_overrides.get("load_parameters_path"):
+  if not local_args.offline_distillation and not teacher_overrides.get("load_parameters_path"):
     raise ValueError(
         "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
         "in your config or arguments."
@@ -618,8 +651,26 @@ def main(argv: Sequence[str]) -> None:
   teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
 
   # 4. Run Training
-  train_distill(student_config, teacher_config)
+  train_distill(student_config, teacher_config, local_args.offline_distillation, local_args.offline_data_dir)
 
 
 if __name__ == "__main__":
-  app.run(main)
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--offline_distillation",
+      action="store_true",
+      help="Pass this flag to enable offline distillation.",
+  )
+  parser.add_argument(
+      "--offline_data_dir",
+      type=str,
+      required=False,
+      default=None,
+      help="GCS or local path to the pre-generated ArrayRecord teacher data.",
+  )
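+  # Note: --offline_data_dir is not validated here; it must accompany
+  # --offline_distillation, otherwise the offline iterator receives None.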
+
+  # parse_known_args separates our custom flags from MaxText's standard args.
+  local_args, remaining_args = parser.parse_known_args()
+
+  main_wrapper = functools.partial(main, local_args=local_args)
+  app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
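
For reference, a minimal sketch of how parse_known_args splits a mixed command line so the custom flags never reach pyconfig.initialize (config path and override values are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--offline_distillation", action="store_true")
    parser.add_argument("--offline_data_dir", type=str, default=None)

    # Custom flags are consumed; the MaxText config file and key=value overrides pass through.
    argv = ["--offline_distillation", "--offline_data_dir", "/tmp/topk", "configs/base.yml", "steps=100"]
    local_args, remaining = parser.parse_known_args(argv)

    assert local_args.offline_distillation is True
    assert local_args.offline_data_dir == "/tmp/topk"
    assert remaining == ["configs/base.yml", "steps=100"]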