Skip to content

Commit f1fc688

Browse files
Merge pull request #3185 from hx89:hx89/checkpoint-te-quantizations
PiperOrigin-RevId: 874746242
2 parents f33b16c + 807bbc3 commit f1fc688

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

src/maxtext/layers/decoders.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import functools
2020
from typing import Any
21+
import warnings
2122

2223
from flax import linen as nn
2324
from flax import nnx
@@ -283,7 +284,7 @@ def setup(self):
283284
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
284285
)
285286

286-
def minimal_policy(self, with_context=False):
287+
def minimal_policy(self, with_context=False, with_quantization=False):
287288
"""Helper for creating minimal checkpoint policies."""
288289
names = [
289290
"query_proj",
@@ -298,6 +299,8 @@ def minimal_policy(self, with_context=False):
298299
]
299300
if with_context:
300301
names.append("context")
302+
if with_quantization:
303+
names.append("quantization")
301304
return jax.checkpoint_policies.save_only_these_names(*names)
302305

303306
def get_remat_policy(self):
@@ -314,6 +317,14 @@ def get_remat_policy(self):
314317
elif cfg.remat_policy == "minimal":
315318
# save all except context
316319
policy = self.minimal_policy()
320+
elif cfg.remat_policy == "minimal_with_quantization":
321+
if cfg.scan_layers:
322+
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
323+
policy = self.minimal_policy(with_context=False, with_quantization=True)
324+
elif cfg.remat_policy == "minimal_with_context_and_quantization":
325+
if cfg.scan_layers:
326+
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
327+
policy = self.minimal_policy(with_context=True, with_quantization=True)
317328
elif cfg.remat_policy == "save_dot_with_context_except_mlp":
318329
policy = jax.checkpoint_policies.save_only_these_names(
319330
"query_proj",

src/maxtext/layers/quantizations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,10 @@ class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase):
805805
def generate_quantizer_set(self, postfix: str = ""):
806806
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
807807
return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args
808-
postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=fp8_recipe
808+
postfix=postfix,
809+
variable_collection=OVERWRITE_WITH_GRADIENT,
810+
quantization_checkpoint_name="quantization",
811+
fp8_recipe=fp8_recipe
809812
)
810813

811814
@nn.compact

0 commit comments

Comments
 (0)