 model structures with Tunix's training interfaces.
 """

-from typing import Any, Iterator
+from typing import Any, Iterator, Optional, List, Callable

 import flax
 from flax import nnx
@@ -110,9 +110,52 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class MonitoredLogitStrategy(logit.LogitStrategy):
+class CombinedDistillationStrategy(logit.LogitStrategy):
   """Logit Strategy that returns detailed metrics for TensorBoard."""

+  def __init__(
+      self,
+      student_forward_fn: Callable[..., jax.Array],
+      teacher_forward_fn: Callable[..., jax.Array],
+      labels_fn: Callable[..., jax.Array],
+      temperature: float = 2.0,
+      alpha: float = 0.5,
+      beta_feature: float = 0.0,
+      layer_indices: Optional[List[int]] = None,
+      feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
+      cosine_distance_axis: int | tuple[int, ...] = -1,
+  ):
+    """Initializes the combined strategy on top of Tunix's `logit.LogitStrategy`.
+
+    Args:
+      student_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute student model outputs.
+      teacher_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute teacher model outputs.
+      labels_fn: Inherited from `logit.LogitStrategy`. Function to compute labels from model inputs.
+      temperature: Inherited from `logit.LogitStrategy`. Temperature for softening probabilities (> 0).
+      alpha: Inherited from `logit.LogitStrategy`. Weight to balance distillation loss and task loss (0.0 to 1.0).
+      beta_feature: Weight to balance the feature loss (0.0 to 1.0). 0.0 disables the feature loss.
+      layer_indices: Layer indices at which to apply the feature loss (all layers if None).
+      feature_loss_fn: A function that takes two jax.Arrays (student_map,
+        teacher_map) and returns a scalar loss. Defaults to cosine distance.
+      cosine_distance_axis: The axis to use for the cosine distance computation
+        if feature_loss_fn is not provided. Defaults to -1.
+    """
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        labels_fn=labels_fn,
+        temperature=temperature,
+        alpha=alpha,
+    )
+    self.beta_feature = beta_feature
+    self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None
+
+    self.feature_loss_fn = feature_loss_fn
+    if feature_loss_fn is None:
+      self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
+          optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
+      )
+
   def compute_loss(
       self,
       student_output: jax.Array,
@@ -123,8 +166,18 @@ def compute_loss(
     # Calculate Distillation Loss (KL Divergence)
     # Scale logits by temperature T for soft targets
     # We use explicit float32 casting for stability in loss calculation
-    s_logits = student_output.astype(jnp.float32)
-    t_logits = teacher_output.astype(jnp.float32)
+    s_logits = student_output[0].astype(jnp.float32)
+    t_logits = teacher_output[0].astype(jnp.float32)
+
+    # Shape: [num_layers, batch, seq, hidden_dim]
+    s_features = student_output[-1]
+    t_features = teacher_output[-1]
+
+    if (s_features is None or t_features is None) and self.beta_feature > 0.0:
+      raise ValueError(
+          "Features extracted from student or teacher model are None, but distill_beta > 0.0. "
+          "Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
+      )

     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
@@ -144,14 +197,31 @@ def compute_loss(
     teacher_hard_loss = jnp.mean(ce_loss_teacher)

     # 3. Combine losses
-    total_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
+    base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
+
+    feature_loss = 0.0
+    if self.beta_feature > 0.0:
+
+      if self.layer_indices is not None:
+        # jnp.take slices along axis=0 (the layer dimension)
+        s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0)
+        t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0)
+      else:
+        s_features_sliced = s_features
+        t_features_sliced = t_features
+
+      feature_loss = self.beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
+
+    total_loss = base_logit_loss + feature_loss

     # 4. Return Loss AND Metrics
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
         "distill/kl_div": jnp.mean(kl_div),
         "distill/teacher_loss": teacher_hard_loss,
+        "distill/out_proj_feature_loss": feature_loss,
+        "distill/total_loss": total_loss,
     }
     return total_loss, metrics

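Note: a minimal usage sketch of the new `CombinedDistillationStrategy` (not part of this commit). The three callables and the layer indices below are hypothetical placeholders for whatever the Tunix trainer supplies; only the constructor signature comes from the diff above.

```python
# Hypothetical wiring (sketch only): the callables are stand-ins and are
# never invoked here; the strategy class is the one defined in this commit.
import jax


def student_forward_fn(*args, **kwargs) -> jax.Array:  # placeholder
  raise NotImplementedError


def teacher_forward_fn(*args, **kwargs) -> jax.Array:  # placeholder
  raise NotImplementedError


def labels_fn(*args, **kwargs) -> jax.Array:  # placeholder
  raise NotImplementedError


strategy = CombinedDistillationStrategy(
    student_forward_fn=student_forward_fn,
    teacher_forward_fn=teacher_forward_fn,
    labels_fn=labels_fn,
    temperature=2.0,       # softens logits for the KL term
    alpha=0.5,             # balance between soft (KL) and hard (CE) loss
    beta_feature=0.1,      # > 0.0 enables the feature-matching term
    layer_indices=[4, 8],  # hypothetical layers; None uses all sowed layers
)
```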
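And a standalone sketch of what the default `feature_loss_fn` computes when none is supplied: the same `jnp.take` slicing and `optax.cosine_distance` call used in `compute_loss`, applied here to dummy feature maps. Matching student/teacher shapes and the layer indices are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import optax

# Dummy activations shaped [num_layers, batch, seq, hidden_dim], matching the
# shape comment inside compute_loss. Identical shapes are assumed here.
key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
s_features = jax.random.normal(key_s, (12, 2, 16, 64))
t_features = jax.random.normal(key_t, (12, 2, 16, 64))

layer_indices = jnp.array([4, 8])  # hypothetical layers to match

# Slice along axis=0, the layer dimension, exactly as compute_loss does.
s_sel = jnp.take(s_features, layer_indices, axis=0)
t_sel = jnp.take(t_features, layer_indices, axis=0)

# Default feature loss: mean cosine distance over the hidden dimension.
feature_loss = jnp.mean(optax.cosine_distance(s_sel, t_sel, axis=-1))
print(feature_loss)  # scalar; compute_loss then scales it by beta_feature
```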