Skip to content

Commit 0d80d3d

Browse files
committed
fix SFT after a recent distillation training code refactor
1 parent f99723a commit 0d80d3d

1 file changed

Lines changed: 12 additions & 6 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distill
131131
"""
132132

133133
def model_forward_fn(
134-
model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs
134+
model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None,
135+
**kwargs
135136
) -> distillation_utils.DistillationForwardOutput:
136137
"""Forward pass wrapper adapted for raw MaxText models."""
137138
del attention_mask # Unused
@@ -141,8 +142,8 @@ def model_forward_fn(
141142
decoder_positions=positions,
142143
decoder_segment_ids=decoder_segment_ids,
143144
enable_dropout=config.enable_dropout,
144-
decoder_target_tokens=kwargs.get("targets", None),
145-
decoder_target_mask=kwargs.get("targets_segmentation", None),
145+
decoder_target_tokens=kwargs.get("decoder_target_tokens", None),
146+
decoder_target_mask=kwargs.get("decoder_target_mask", None),
146147
)
147148
out_projection_activations = None
148149
if config.distill_beta > 0.0:
@@ -214,7 +215,7 @@ def _train_step(self, model, optimizer, inputs):
214215

215216
batch = self.gen_model_input_fn(inputs)
216217

217-
def loss_wrapper(student, teacher, batch):
218+
def loss_wrapper(student, teacher, batch):
218219
if "teacher_output" in batch:
219220
teacher_output = batch["teacher_output"]
220221
else:
@@ -224,6 +225,8 @@ def loss_wrapper(student, teacher, batch):
224225
positions=batch["positions"],
225226
attention_mask=batch.get("attention_mask"),
226227
decoder_segment_ids=batch.get("decoder_segment_ids"),
228+
decoder_target_tokens=batch.get("targets", None),
229+
decoder_target_mask=batch.get("targets_segmentation", None),
227230
cache=None,
228231
)
229232

@@ -235,9 +238,12 @@ def loss_wrapper(student, teacher, batch):
235238
positions=batch["positions"],
236239
attention_mask=batch.get("attention_mask"),
237240
decoder_segment_ids=batch.get("decoder_segment_ids"),
241+
decoder_target_tokens=batch.get("targets", None),
242+
decoder_target_mask=batch.get("targets_segmentation", None),
238243
cache=None,
239244
)
240-
labels = self.strategy.labels_fn(batch["targets"])
245+
# we should apply a mask for labels to disable segment-separator tokens
246+
labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
241247
return self.strategy.compute_loss(student_output, teacher_output, labels)
242248

243249
# Because student is the 0th argument, argnums=0 guarantees
@@ -434,7 +440,7 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
434440

435441
# 3. Define Distillation Strategy
436442
def labels_fn(targets, targets_segmentation=None, **kwargs):
437-
"""Converts integer targets to masked one-hot vectors for hard label loss."""
443+
"""Converts integer targets to masked one-hot vectors for hard label loss."""
438444
del kwargs # Unused
439445
one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
440446
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]

0 commit comments

Comments (0)