@@ -60,6 +60,10 @@ class MaxTextTrainingInput(distillation_trainer.TrainingInput):
6060 decoder_segment_ids : jax .Array = None
6161 #: Ground truth target tokens (used for loss calculation and logging).
6262 targets : jax .Array = None
63+ #: Position indices for the target tokens.
64+ targets_position : jax .Array = None
65+ #: Segment IDs for packed target tokens.
66+ targets_segmentation : jax .Array = None
6367
6468
6569# -----------------------------------------------------------------------------
@@ -106,6 +110,11 @@ def __next__(self) -> MaxTextTrainingInput:
106110 input_mask = jnp .ones_like (batch ["inputs" ], dtype = bool )
107111 seg_ids = None
108112
113+ # In SFT mode, 'targets' contains prompt tokens that must be masked out when computing the loss.
114+ # When packing is used, targets_segmentation is expected to be a combined target+packing mask.
115+ targets_segmentation = batch .get ("targets_segmentation" , jnp .ones_like (batch ["targets" ]))
116+ targets_position = batch .get ("targets_position" , batch .get ("inputs_position" ))
117+
109118 # pylint: disable=unexpected-keyword-arg
110119 return MaxTextTrainingInput (
111120 input_tokens = batch ["inputs" ],
@@ -114,6 +123,8 @@ def __next__(self) -> MaxTextTrainingInput:
114123 positions = batch ["inputs_position" ],
115124 decoder_segment_ids = seg_ids ,
116125 targets = batch ["targets" ],
126+ targets_position = targets_position ,
127+ targets_segmentation = targets_segmentation ,
117128 )
118129
119130
@@ -134,6 +145,7 @@ def __init__(
134145 layer_indices : Optional [List [int ]] = None ,
135146 feature_loss_fn : Callable [[jax .Array , jax .Array ], jax .Array ] | None = None ,
136147 cosine_distance_axis : int | tuple [int , ...] = - 1 ,
148+ sft_mode : bool = False ,
137149 ):
138150 """Initializes the Combined strategy using tunix logit.LogitStrategy.
139151
@@ -165,6 +177,7 @@ def __init__(
165177 self .feature_loss_fn = lambda student_features , teacher_features : jnp .mean (
166178 optax .cosine_distance (student_features , teacher_features , axis = cosine_distance_axis )
167179 )
180+ self .sft_mode = sft_mode
168181
169182 def compute_loss (
170183 self ,
@@ -192,19 +205,23 @@ def compute_loss(
192205 log_student_probs_temp = jax .nn .log_softmax (s_logits / self .temperature , axis = - 1 )
193206 teacher_probs_temp = jax .nn .softmax (t_logits / self .temperature , axis = - 1 )
194207
208+ # By this point, labels are expected to have all SFT masks already applied.
209+ labels_mask = jnp .any (labels != 0 , axis = - 1 , keepdims = True ) if self .sft_mode else None
210+ mean_mask = jnp .squeeze (labels_mask , axis = - 1 ) if labels_mask is not None else None
211+
195212 # KL(Teacher || Student)
196- kl_div = optax .kl_divergence (log_student_probs_temp , teacher_probs_temp )
213+ kl_div = optax .kl_divergence (log_student_probs_temp , teacher_probs_temp , where = labels_mask )
197214
198215 # Scale gradients by T^2 (Hinton et al.)
199- soft_loss = jnp .mean (kl_div ) * (self .temperature ** 2 )
216+ soft_loss = jnp .mean (kl_div , where = mean_mask ) * (self .temperature ** 2 )
200217
201218 # 1. Student Hard Loss (Existing)
202- ce_loss_student = optax .softmax_cross_entropy (logits = s_logits , labels = labels )
203- hard_loss = jnp .mean (ce_loss_student )
219+ ce_loss_student = optax .softmax_cross_entropy (logits = s_logits , labels = labels , where = labels_mask )
220+ hard_loss = jnp .mean (ce_loss_student , where = mean_mask )
204221
205222 # 2. Teacher Hard Loss (For Verification)
206- ce_loss_teacher = optax .softmax_cross_entropy (logits = t_logits , labels = labels )
207- teacher_hard_loss = jnp .mean (ce_loss_teacher )
223+ ce_loss_teacher = optax .softmax_cross_entropy (logits = t_logits , labels = labels , where = labels_mask )
224+ teacher_hard_loss = jnp .mean (ce_loss_teacher , where = mean_mask )
208225
209226 # 3. Combine losses
210227 base_logit_loss = (self .alpha * soft_loss ) + ((1.0 - self .alpha ) * hard_loss )