Commit 94f24a0

add activation checkpoint for batch_split
update: clean up and add attention_out for attention, attention_mla
1 parent 00eb74e commit 94f24a0

5 files changed: 18 additions & 0 deletions

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -310,6 +310,7 @@ qkv_proj: 'remat'
 out_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
+attention_out: 'remat'
 
 optimizer_memory_host_offload: False
 parameter_memory_host_offload: False

src/MaxText/configs/types.py

Lines changed: 6 additions & 0 deletions
@@ -872,6 +872,11 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the mla's key and value projection.",
   )
+  attention_out: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the attention output.",
+  )
+
   optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
   parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")

@@ -2060,6 +2065,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       "mla_kv",
       "mla_q",
       "qkv_proj",
+      "attention_out",
       "out_proj",
     ]
     self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
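
For reference, these RematLocation fields pair with tensors tagged via jax.ad_checkpoint.checkpoint_name: names whose location is "device" are collected into tensors_on_device and can be kept in memory, while names left at RematLocation.REMAT are recomputed during the backward pass. A minimal sketch of that mapping (the locations dict and build_policy helper below are illustrative, not MaxText's actual config plumbing):

import jax

# Hypothetical per-tensor settings mirroring the RematAndOffload fields;
# real values come from base.yml / types.py, not from this dict.
locations = {
    "qkv_proj": "remat",
    "out_proj": "remat",
    "mla_q": "remat",
    "mla_kv": "remat",
    "attention_out": "remat",  # setting this to "device" would save the tensor
}

def build_policy(locations):
  # Save only tensors marked "device"; every other tagged tensor (anything
  # left at "remat") is rematerialized in the backward pass.
  on_device = [name for name, loc in locations.items() if loc == "device"]
  return jax.checkpoint_policies.save_only_these_names(*on_device)

policy = build_policy(locations)  # pass as jax.checkpoint(fn, policy=policy)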

src/MaxText/layers/attention_mla.py

Lines changed: 1 addition & 0 deletions
@@ -1038,6 +1038,7 @@ def __call__(
     # Pass the index_mask to the Attention Op
     out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
 
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
       out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
     else:

src/MaxText/layers/attentions.py

Lines changed: 1 addition & 0 deletions
@@ -1132,6 +1132,7 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
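
Both attention paths now tag their output identically. As a rough illustration of what the tag buys (a self-contained sketch, not MaxText's attention op: toy_attention_block and its weights are made up), checkpoint_name only labels the intermediate; the named policy handed to jax.checkpoint decides whether it is stored or recomputed:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def toy_attention_block(x, w_attn, w_out):
  out = jnp.tanh(x @ w_attn)                   # stand-in for the attention op
  out = checkpoint_name(out, "attention_out")  # label the residual by name
  return jnp.mean((out @ w_out) ** 2)

# With "attention_out" in the saved set, the tagged tensor is kept for the
# backward pass; drop it from the policy and it gets rematerialized instead.
policy = jax.checkpoint_policies.save_only_these_names("attention_out")
block = jax.checkpoint(toy_attention_block, policy=policy)

x = jnp.ones((4, 16))
w_attn = 0.1 * jnp.ones((16, 16))
w_out = 0.1 * jnp.ones((16, 8))
grads = jax.grad(block, argnums=(1, 2))(x, w_attn, w_out)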

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 9 additions & 0 deletions
@@ -336,6 +336,7 @@ def mla(
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
   )
+  query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
   key, value = kv_projection(
       inputs,
       positions,
@@ -355,6 +356,8 @@
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
   )
+  key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
+  value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
   out = attention_op_fn(
       query,
       key,
@@ -363,7 +366,9 @@
       model_mode,
       cached_values=[None, None],
   )
+  out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
   out = dot(out, out_weights, axes=2)
+  out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
   return out
 
 
@@ -402,6 +407,7 @@ def query_projection(
       epsilon=epsilon,
       dtype=dtype,
   )
+  low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
   q = dot(low_rank_q, wq_b_weights)
 
   # Split into non-positional and rotary parts.
@@ -451,6 +457,7 @@ def kv_projection(
       epsilon=kv_norm_epsilon,
       dtype=dtype,
   )
+  low_rank_main = jax.ad_checkpoint.checkpoint_name(low_rank_main, "mla_kv")
   key_rope = jnp.expand_dims(low_rank_rope, axis=2)
   key_rope = yarn(
       key_rope,
@@ -690,6 +697,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
   )
   layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
   layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
+  layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
+  layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
   return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
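
With the batch-split path now tagging "query_proj", "key_proj", "value_proj", "attention_out", "out_proj", "mlpwi_0" and "mlpwi_1", a single named policy can decide per tensor whether to save, offload, or recompute it. A rough sketch of wiring such a policy into a rematted layer (decoder_layer and the particular save/offload split are illustrative choices, not what this commit configures; save_and_offload_only_these_names requires a reasonably recent JAX):

import jax

# Illustrative split: keep the attention output on device, push the MoE MLP
# intermediates to host memory, and recompute every other tagged tensor.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["attention_out"],
    names_which_can_be_offloaded=["mlpwi_0", "mlpwi_1"],
    offload_src="device",
    offload_dst="pinned_host",
)

def decoder_layer(params, x):
  # Stand-in for the batch-split MLA + MoE layer, which applies the
  # checkpoint_name tags shown in the diff above.
  return x

remat_layer = jax.checkpoint(decoder_layer, policy=policy)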
