@@ -307,7 +307,7 @@ def _prepare_inputs(
         targets_position=input_data.targets_position,
         targets_segmentation=input_data.targets_segmentation,
         top_k_logits=input_data.top_k_logits,
-        top_k_indices=input_data.top_k_indices
+        top_k_indices=input_data.top_k_indices,
     )
 
   def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -406,7 +406,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
 # -----------------------------------------------------------------------------
 
 
-def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters, is_offline: bool = False, offline_data_dir: str | None = None) -> None:
+def train_distill(
+    student_config: pyconfig.HyperParameters,
+    teacher_config: pyconfig.HyperParameters,
+    is_offline: bool = False,
+    offline_data_dir: str | None = None,
+) -> None:
   """Main distillation training loop.
 
   Orchestrates the loading of both student and teacher models, configures the
@@ -550,29 +555,23 @@ def custom_gen_model_input_fn(batch):
         "targets_segmentation": batch.targets_segmentation,
         "cache": None,
     }
-    
+
     # If we are in online mode then we exit
     if getattr(batch, "top_k_logits", None) is None:
       return inputs_dict
 
     # Scatter the offline arrays into a dense tensor of -10000s
     dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
     dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
-    dense_logits = jnp.put_along_axis(
-        dense_logits,
-        batch.top_k_indices,
-        batch.top_k_logits,
-        axis=-1,
-        inplace=False
-    )
-
+    dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False)
+
     # Inject it as teacher_output so the trainer skips the teacher forward pass
     inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput(
         logits=dense_logits, out_projection_activations=None
     )
-    
+
     return inputs_dict
-  
+
   trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
 
   # 9. Create Iterator Wrappers (Use Utils)
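For readers skimming the hunk above, here is a minimal self-contained sketch of the top-k scatter (toy shapes; everything except jnp.put_along_axis is illustrative, not taken from the PR): the dense tensor starts at -10000.0 so vocab slots outside the stored top-k carry roughly zero probability mass after softmax, and the saved teacher logits are scattered back to their original vocab indices.

    import jax.numpy as jnp

    # Toy shapes standing in for batch.input_tokens.shape + (vocab_size,).
    batch, seq, vocab, k = 2, 4, 8, 3
    top_k_logits = jnp.arange(batch * seq * k, dtype=jnp.float32).reshape(batch, seq, k)
    top_k_indices = jnp.broadcast_to(jnp.array([1, 4, 6]), (batch, seq, k))

    # Slots outside the top-k stay at -10000.0, i.e. ~zero mass after softmax.
    dense = jnp.full((batch, seq, vocab), -10000.0, dtype=jnp.float32)
    # JAX arrays are immutable, so inplace=False is required; a new array is returned.
    dense = jnp.put_along_axis(dense, top_k_indices, top_k_logits, axis=-1, inplace=False)
    assert float(dense[0, 0, 1]) == float(top_k_logits[0, 0, 0])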
@@ -635,7 +634,7 @@ def main(argv: Sequence[str], local_args) -> None:
   student_config = pyconfig.initialize(argv, **student_overrides)
 
   # 3. Initialize TEACHER Config
-  # We isolate the Teacher from Student CLI arguments (like pruning params).  
+  # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides
 
   # Ensure load_parameters_path is set in overrides
@@ -668,7 +667,7 @@ def main(argv: Sequence[str], local_args) -> None:
668667 default = None ,
669668 help = "GCS or local path to the pre-generated ArrayRecord teacher data." ,
670669 )
671-
670+
672671 # parse_known_args separates our custom flags from MaxText's standard args
673672 local_arg , remaining_args = parser .parse_known_args ()
674673
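As a side note, a minimal sketch of the parse_known_args pattern in the last hunk; the flag name is inferred from the offline_data_dir parameter and should be treated as an assumption, not the PR's exact CLI:

    import argparse

    parser = argparse.ArgumentParser()
    # Hypothetical flag mirroring the help text shown in the diff above.
    parser.add_argument(
        "--offline_data_dir",
        type=str,
        default=None,
        help="GCS or local path to the pre-generated ArrayRecord teacher data.",
    )
    # Our custom flags are consumed here; anything MaxText understands is left
    # untouched in remaining_args and can be forwarded to pyconfig.initialize.
    local_arg, remaining_args = parser.parse_known_args()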