Commit 5a4a9c3
Merge pull request #3218 from AI-Hypercomputer:cos_loss
PiperOrigin-RevId: 875535755
2 parents 44039d8 + 5d53a91 commit 5a4a9c3

9 files changed

Lines changed: 244 additions & 167 deletions

src/maxtext/configs/base.yml

Lines changed: 8 additions & 0 deletions
@@ -1122,3 +1122,11 @@ engram_vocab_bases: []
 engram_kernel_size: 4
 # The seed for Engram hash mapping.
 engram_seed: 0
+
+##### Distillation parameters
+distill_alpha: 0.5
+distill_temperature: 1.0
+# distill_beta is used for a cosine similarity loss between intermediate activations of out_proj in the teacher/student models.
+# A value of 0.0 disables this feature.
+distill_beta: 0.0
+distill_layer_indices: None
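
Taken together, these knobs control a weighted sum of three terms: distill_alpha trades off the soft (temperature-scaled KL) loss against the hard cross-entropy loss, and distill_beta scales the cosine feature loss on the layers named by distill_layer_indices. A rough illustrative sketch of that relationship (not code from this commit; the actual loss terms are computed in distillation_utils.py further below):

def combined_distillation_loss(soft_loss, hard_loss, feature_loss, distill_alpha=0.5, distill_beta=0.0):
  # soft_loss is computed from logits softened by distill_temperature;
  # distill_beta == 0.0 disables the cosine feature term entirely.
  return distill_alpha * soft_loss + (1.0 - distill_alpha) * hard_loss + distill_beta * feature_loss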

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
@@ -1057,6 +1057,8 @@ class Distillation(BaseModel):
   # --- Loss Params ---
   distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
   distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
+  distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
+  distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")


 class TrainingLoop(BaseModel):
@@ -2010,6 +2012,13 @@ def validate_and_set_hlo_dump_defaults():
     # Validate and initiate hlo dump related configs
     validate_and_set_hlo_dump_defaults()

+    # Validate nnx sow incompatibility
+    if self.distill_beta > 0.0:
+      if not self.scan_layers:
+        raise ValueError("a value of self.distill_beta > 0.0 requires self.scan_layers = True")
+      if not self.enable_nnx:
+        raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True")
+
     # D. CALCULATE MODEL DIMENSIONS from global_parameter_scale
     # This allows scaling the model size up or down easily with a single power-of-two factor.
     emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale)
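
The new fields and the fail-fast guard can be exercised in isolation. A minimal sketch using plain pydantic (it only mirrors the field declarations above; the real MaxText Config wiring and where scan_layers/enable_nnx live are not shown):

from pydantic import BaseModel, Field

class DistillationSketch(BaseModel):
  distill_alpha: float = Field(0.5)
  distill_temperature: float = Field(1.0)
  distill_beta: float = Field(0.0)
  distill_layer_indices: None | list = Field(None)

cfg = DistillationSketch(distill_beta=0.1, distill_layer_indices=[3, 7])
# Mirror of the guard added above: the feature loss needs scanned layers and NNX.
scan_layers, enable_nnx = True, True
if cfg.distill_beta > 0.0 and not (scan_layers and enable_nnx):
  raise ValueError("distill_beta > 0.0 requires scan_layers=True and enable_nnx=True")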

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 0 deletions
@@ -1162,5 +1162,7 @@ def __call__(
     out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
     out = out * jax.nn.sigmoid(gate)
     out = self.out_projection(out, out_sharding=out_sharding)
+    if self.config.distill_beta > 0.0:
+      self.sow(nnx.Intermediate, "out_projection_activations", out)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
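
For readers unfamiliar with sow: in Flax NNX it records a value on the module as an nnx.Intermediate variable without changing the function's return value. A small, self-contained sketch with a toy module and made-up shapes (not MaxText code; attribute access on the sowed variable assumes a recent flax.nnx API):

from flax import nnx
import jax.numpy as jnp

class TinyBlock(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.proj = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x):
    y = self.proj(x)
    # Analogous to "out_projection_activations" above: stash the projection output.
    self.sow(nnx.Intermediate, "proj_activations", y)
    return y

block = TinyBlock(4, 8, rngs=nnx.Rngs(0))
_ = block(jnp.ones((2, 3, 4)))
# Sowed values accumulate as a tuple on the module; the last entry is the newest.
print(block.proj_activations.value[-1].shape)  # (2, 3, 8)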

src/maxtext/models/models.py

Lines changed: 7 additions & 0 deletions
@@ -477,6 +477,12 @@ def __call__(
     if audio_embeddings is not None:
       audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

+    mutable_collections = []
+    if self.config.record_internal_nn_metrics:
+      mutable_collections.append("intermediates")
+    if self.config.distill_beta > 0.0 and "intermediates" not in mutable_collections:
+      mutable_collections.append("intermediates")
+
     logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.token_embedder,
         decoder_input_tokens=decoder_input_tokens,
@@ -495,6 +501,7 @@ def __call__(
         kv_caches=kv_caches,
         attention_metadata=attention_metadata,
         deepstack_visual_embeds=deepstack_visual_embeds,
+        mutable=mutable_collections,
     )

     # Materialize hidden state when vocab tiling is enabled
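
The mutable=["intermediates"] plumbing follows the Flax Linen convention that sowed collections must be marked mutable during apply, otherwise the recorded values are dropped. A small standalone Linen sketch of that behavior (illustrative only, not MaxText code; the decoder call above threads the same flag through):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Tiny(nn.Module):
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(8)(x)
    self.sow("intermediates", "dense_out", y)  # only kept if the collection is mutable
    return y

m = Tiny()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
out, state = m.apply({"params": variables["params"]}, jnp.ones((2, 4)), mutable=["intermediates"])
print(state["intermediates"]["dense_out"][0].shape)  # (2, 8)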

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 75 additions & 5 deletions
@@ -18,7 +18,7 @@
 model structures with Tunix's training interfaces.
 """

-from typing import Any, Iterator
+from typing import Any, Iterator, Optional, List, Callable

 import flax
 from flax import nnx
@@ -110,9 +110,52 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class MonitoredLogitStrategy(logit.LogitStrategy):
+class CombinedDistillationStrategy(logit.LogitStrategy):
   """Logit Strategy that returns detailed metrics for TensorBoard."""

+  def __init__(
+      self,
+      student_forward_fn: Callable[..., jax.Array],
+      teacher_forward_fn: Callable[..., jax.Array],
+      labels_fn: Callable[..., jax.Array],
+      temperature: float = 2.0,
+      alpha: float = 0.5,
+      beta_feature: float = 0.0,
+      layer_indices: Optional[List[int]] = None,
+      feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
+      cosine_distance_axis: int | tuple[int, ...] = -1,
+  ):
+    """Initializes the Combined strategy using tunix logit.LogitStrategy.
+
+    Args:
+      student_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute student model outputs.
+      teacher_forward_fn: Inherited from `logit.LogitStrategy`. Function to compute teacher model outputs.
+      labels_fn: Inherited from `logit.LogitStrategy`. Function to compute labels from model inputs.
+      temperature: Inherited from `logit.LogitStrategy`. Temperature for softening probabilities (> 0).
+      alpha: Inherited from `logit.LogitStrategy`. Weight to balance distillation loss and task loss (0.0 to 1.0).
+      beta_feature: Weight to balance feature loss (0.0 to 1.0). 0.0 disables feature loss.
+      layer_indices: Layer indices to apply feature loss.
+      feature_loss_fn: A function that takes two jax.Arrays (student_map,
+        teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
+      cosine_distance_axis: The axis to use for cosine distance computation if
+        feature_loss_fn is not provided. Defaults to -1.
+    """
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        labels_fn=labels_fn,
+        temperature=temperature,
+        alpha=alpha,
+    )
+    self.beta_feature = beta_feature
+    self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None
+
+    self.feature_loss_fn = feature_loss_fn
+    if feature_loss_fn is None:
+      self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
+          optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
+      )
+
   def compute_loss(
       self,
       student_output: jax.Array,
@@ -123,8 +166,18 @@ def compute_loss(
     # Calculate Distillation Loss (KL Divergence)
     # Scale logits by temperature T for soft targets
     # We use explicit float32 casting for stability in loss calculation
-    s_logits = student_output.astype(jnp.float32)
-    t_logits = teacher_output.astype(jnp.float32)
+    s_logits = student_output[0].astype(jnp.float32)
+    t_logits = teacher_output[0].astype(jnp.float32)
+
+    # Shape: [num_layers, batch, seq, hidden_dim]
+    s_features = student_output[-1]
+    t_features = teacher_output[-1]
+
+    if (s_features is None or t_features is None) and self.beta_feature > 0.0:
+      raise ValueError(
+          "Features extracted from student or teacher model are None, but distill_beta > 0.0. "
+          "Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
+      )

     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
@@ -144,14 +197,31 @@ def compute_loss(
     teacher_hard_loss = jnp.mean(ce_loss_teacher)

     # 3. Combine losses
-    total_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
+    base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
+
+    feature_loss = 0.0
+    if self.beta_feature > 0.0:
+
+      if self.layer_indices is not None:
+        # jnp.take slices along axis=0 (the layer dimension)
+        s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0)
+        t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0)
+      else:
+        s_features_sliced = s_features
+        t_features_sliced = t_features
+
+      feature_loss = self.beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
+
+    total_loss = base_logit_loss + feature_loss

     # 4. Return Loss AND Metrics
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
         "distill/kl_div": jnp.mean(kl_div),
         "distill/teacher_loss": teacher_hard_loss,
+        "distill/out_proj_feature_loss": feature_loss,
+        "distill/total_loss": total_loss,
     }
     return total_loss, metrics
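
To make the feature-loss path concrete, here is a toy, self-contained sketch of the slicing and default cosine-distance reduction used above (shapes and values are made up; layer_indices=[3, 7] is only an example of distill_layer_indices):

import jax.numpy as jnp
import optax

# Stacked per-layer out_proj activations: [num_layers, batch, seq, hidden]
s_features = jnp.ones((8, 2, 16, 64))
t_features = jnp.full((8, 2, 16, 64), 0.5)
layer_indices = jnp.array([3, 7])

# Select the distilled layers along axis 0, exactly as compute_loss does.
s_sliced = jnp.take(s_features, layer_indices, axis=0)
t_sliced = jnp.take(t_features, layer_indices, axis=0)

# Default feature loss: mean cosine distance over the last (hidden) axis.
feature_loss = jnp.mean(optax.cosine_distance(s_sliced, t_sliced))
print(feature_loss)  # 0.0 here, since the toy student/teacher features are parallel vectors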

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 9 additions & 3 deletions
@@ -135,14 +135,16 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
     del kwargs  # Unused
     del attention_mask  # Unused
     del cache  # Unused
-
     logits = model(
         decoder_input_tokens=input_tokens,
         decoder_positions=positions,
         decoder_segment_ids=decoder_segment_ids,
         enable_dropout=config.enable_dropout,
     )
-    return logits
+    hidden_features = None
+    if config.distill_beta > 0.0:
+      hidden_features = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True)
+    return logits, hidden_features

   return model_forward_fn

@@ -356,14 +358,18 @@ def labels_fn(targets, **kwargs):
   teacher_forward_fn = create_forward_fn(teacher_config)

   # Use Monitored strategy from Utils
-  strategy = distillation_utils.MonitoredLogitStrategy(
+  strategy = distillation_utils.CombinedDistillationStrategy(
       student_forward_fn=student_forward_fn,
       teacher_forward_fn=teacher_forward_fn,
       labels_fn=labels_fn,
       temperature=student_config.distill_temperature,
       alpha=student_config.distill_alpha,
+      beta_feature=student_config.distill_beta,
+      layer_indices=student_config.distill_layer_indices,
   )

+  student_model, teacher_model = strategy.pre_process_models(student_model, teacher_model)
+
   # 4. Optimizer & Config
   optimizer = get_distillation_optimizer(student_config, student_config.steps)
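
The net effect of these two hunks is a new forward-function contract: each forward pass now returns a (logits, hidden_features) pair, which CombinedDistillationStrategy.compute_loss unpacks as student_output[0] and student_output[-1]. A tiny illustrative sketch of that contract (fake shapes, no real model involved):

import jax.numpy as jnp

def fake_forward_fn(distill_beta=0.0):
  logits = jnp.zeros((2, 16, 32))                                        # [batch, seq, vocab]
  features = jnp.zeros((4, 2, 16, 8)) if distill_beta > 0.0 else None    # [layers, batch, seq, hidden]
  return logits, features

student_output = fake_forward_fn(distill_beta=0.1)
s_logits, s_features = student_output[0], student_output[-1]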

src/maxtext/utils/maxtext_utils.py

Lines changed: 28 additions & 0 deletions
@@ -897,6 +897,34 @@ def get_nested_value(dictionary, nested_key, default=None):
   return current_level


+def get_intermediate_value(model, nested_key, default=None, clear=False):
+  """
+  Retrieves an intermediate value from an NNX model. This function has context about
+  where the intermediate value is located.
+
+  Args:
+    model: The NNX model.
+    nested_key: A string representing the nested key, e.g., hidden_states_norm_out.
+    default: The value to return if the nested key is not found.
+    clear: Clears the intermediate value from the model.
+
+  Returns:
+    The value associated with the nested key, or the default value if not found.
+  """
+  intermediate_value = default
+  match nested_key:
+    case "out_projection_activations":
+      if nested_key in model.decoder.layers["self_attention"]:
+        intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
+        if clear:
+          del model.decoder.layers["self_attention"][nested_key]
+    case _:
+      # Default case to handle any unknown nested keys
+      raise ValueError(f"Incorrect nested_key: {nested_key}")
+
+  return intermediate_value
+
+
 def update_state_param(state, target_path, value):
   """
   Updates a specific parameter in state.params at the given path.
