@@ -131,7 +131,8 @@ def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distill
131131 """
132132
133133 def model_forward_fn (
134- model , input_tokens , positions , attention_mask , decoder_segment_ids = None , cache = None , ** kwargs
134+ model , input_tokens , positions , attention_mask , decoder_segment_ids = None , cache = None ,
135+ ** kwargs
135136 ) -> distillation_utils .DistillationForwardOutput :
136137 """Forward pass wrapper adapted for raw MaxText models."""
137138 del attention_mask # Unused
@@ -141,8 +142,8 @@ def model_forward_fn(
141142 decoder_positions = positions ,
142143 decoder_segment_ids = decoder_segment_ids ,
143144 enable_dropout = config .enable_dropout ,
144- decoder_target_tokens = kwargs .get ("targets" , None ),
145- decoder_target_mask = kwargs .get ("targets_segmentation" , None ),
145+ decoder_target_tokens = kwargs .get ("decoder_target_tokens" , None ),
146+ decoder_target_mask = kwargs .get ("decoder_target_mask" , None ),
146147 )
147148 out_projection_activations = None
148149 if config .distill_beta > 0.0 :
@@ -214,7 +215,7 @@ def _train_step(self, model, optimizer, inputs):
214215
215216 batch = self .gen_model_input_fn (inputs )
216217
217- def loss_wrapper (student , teacher , batch ):
218+ def loss_wrapper (student , teacher , batch ):
218219 if "teacher_output" in batch :
219220 teacher_output = batch ["teacher_output" ]
220221 else :
@@ -224,6 +225,8 @@ def loss_wrapper(student, teacher, batch):
224225 positions = batch ["positions" ],
225226 attention_mask = batch .get ("attention_mask" ),
226227 decoder_segment_ids = batch .get ("decoder_segment_ids" ),
228+ decoder_target_tokens = batch .get ("targets" , None ),
229+ decoder_target_mask = batch .get ("targets_segmentation" , None ),
227230 cache = None ,
228231 )
229232
@@ -235,9 +238,12 @@ def loss_wrapper(student, teacher, batch):
235238 positions = batch ["positions" ],
236239 attention_mask = batch .get ("attention_mask" ),
237240 decoder_segment_ids = batch .get ("decoder_segment_ids" ),
241+ decoder_target_tokens = batch .get ("targets" , None ),
242+ decoder_target_mask = batch .get ("targets_segmentation" , None ),
238243 cache = None ,
239244 )
240- labels = self .strategy .labels_fn (batch ["targets" ])
245+ # we should apply a mask for labels to disable segment-separator tokens
246+ labels = self .strategy .labels_fn (batch ["targets" ], targets_segmentation = batch .get ("targets_segmentation" , None ))
241247 return self .strategy .compute_loss (student_output , teacher_output , labels )
242248
243249 # Because student is the 0th argument, argnums=0 guarantees
@@ -434,7 +440,7 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
434440
435441 # 3. Define Distillation Strategy
436442 def labels_fn (targets , targets_segmentation = None , ** kwargs ):
437- """Converts integer targets to masked one-hot vectors for hard label loss."""
443+ """Converts integer targets to masked one-hot vectors for hard label loss."""
438444 del kwargs # Unused
439445 one_hot = jax .nn .one_hot (targets , student_config .vocab_size )
440446 mask = jnp .not_equal (targets , pad_id ).astype (one_hot .dtype )[..., None ]
0 commit comments