
Commit a4f874d

Merge pull request #3078 from AI-Hypercomputer:sujinesh/colocated_python_checkpointing
PiperOrigin-RevId: 876389548
2 parents 12a7777 + 35a666e

8 files changed

Lines changed: 80 additions & 21 deletions


benchmarks/recipes/user_configs.py

Lines changed: 5 additions & 1 deletion
@@ -53,6 +53,7 @@ class UserConfig:
   zone: str = "us-east5-b"
   device_type: str = "v6e-256"
   priority: str = "medium"
+  base_output_directory: str = None
 
   # Images for env
   server_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"

@@ -97,7 +98,7 @@ def __post_init__(self):
         self.worker_flags,
     )
     self.headless_workload_name = f"{self.user[:3]}-headless"
-    self.base_output_directory = f"gs://{self.user}-{self.region}/{self.user}-"
+    self.base_output_directory = self.base_output_directory or f"gs://{self.user}-{self.region}/{self.user}-"
 
     device_base_type = self.device_type.split("-", maxsplit=1)[0]
     self.models = build_user_models(

@@ -124,4 +125,7 @@ def __post_init__(self):
     selected_model_framework=["pathways"],
     selected_model_names=["llama3_1_8b_8192"],
     priority="medium",
+    base_output_directory=None,  # GCS Bucket path
+    # Optional parameters, useful for single controller data loading optimizations
+    # proxy_flags="--sidecar_name=external",
 )
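The effect of this change: base_output_directory becomes a user-settable field, and __post_init__ only falls back to the derived gs://{user}-{region}/... path when it is left unset. A minimal sketch of the fallback behaviour, using a hypothetical trimmed-down stand-in for UserConfig (the real class has many more fields):

from dataclasses import dataclass


@dataclass
class _OutputDirDemo:
  # Hypothetical stand-in: only the fields needed to show the fallback.
  user: str = "alice"
  region: str = "us-east5"
  base_output_directory: str = None

  def __post_init__(self):
    # Keep an explicit value if one was provided, otherwise derive it.
    self.base_output_directory = self.base_output_directory or f"gs://{self.user}-{self.region}/{self.user}-"


print(_OutputDirDemo().base_output_directory)
# gs://alice-us-east5/alice-
print(_OutputDirDemo(base_output_directory="gs://my-bucket/runs/").base_output_directory)
# gs://my-bucket/runs/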

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ tensorflow
 tiktoken
 tokamax
 transformers
+uvloop
 qwix
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -155,7 +155,7 @@ opt-einsum>=3.4.0
 optax>=0.2.6
 optree>=0.18.0
 optype>=0.14.0
-orbax-checkpoint>=0.11.28
+orbax-checkpoint>=0.11.33
 packaging>=25.0
 pandas>=2.3.3
 parameterized>=0.9.0

@@ -245,6 +245,7 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
+uvloop>=0.19.0
 virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -149,7 +149,7 @@ opt-einsum>=3.4.0
 optax>=0.2.6
 optree>=0.18.0
 optype>=0.14.0
-orbax-checkpoint>=0.11.28
+orbax-checkpoint>=0.11.33
 packaging>=25.0
 pandas>=2.3.3
 parameterized>=0.9.0

@@ -237,6 +237,7 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
+uvloop>=0.19.0
 virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1

src/maxtext/common/checkpointing.py

Lines changed: 14 additions & 0 deletions
@@ -217,6 +217,9 @@ def create_orbax_checkpoint_manager(
     enable_continuous_checkpointing: bool = False,
     max_num_checkpoints_to_keep: int = 10,
     checkpoint_storage_concurrent_gb: int = 96,
+    enable_single_controller: bool = False,
+    colocated_python_checkpointing: bool = False,
+    enable_single_replica_ckpt_restoring: bool = False,
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:

@@ -269,6 +272,17 @@ def create_orbax_checkpoint_manager(
         logger=orbax_logger,
     )
 
+  # Use Colocated Python checkpointing optimization (Single Controller only).
+  if enable_single_controller and colocated_python_checkpointing:
+    max_logging.log("Registering colocated python array handler")
+    checkpointing_impl = ocp.pathways.CheckpointingImpl.from_options(
+        use_colocated_python=True,
+    )
+    ocp.pathways.register_type_handlers(
+        use_single_replica_array_handler=enable_single_replica_ckpt_restoring,
+        checkpointing_impl=checkpointing_impl,
+    )
+
   max_logging.log("Checkpoint manager created!")
   return manager
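For orientation, a rough usage sketch of the new arguments. The leading arguments of create_orbax_checkpoint_manager (checkpoint directory, save interval, and so on) are not part of this hunk and are elided below; the import path and the "..." placeholder are assumptions, not the exact call site:

# Sketch only: "..." stands for the existing leading arguments of the function,
# which are unchanged by this commit.
from maxtext.common.checkpointing import create_orbax_checkpoint_manager  # import path assumed from the file location

manager = create_orbax_checkpoint_manager(
    ...,  # checkpoint directory, intervals, etc. (unchanged, elided here)
    enable_continuous_checkpointing=False,
    max_num_checkpoints_to_keep=10,
    checkpoint_storage_concurrent_gb=96,
    enable_single_controller=True,          # running under Pathways / single controller
    colocated_python_checkpointing=True,    # triggers the ocp.pathways colocated handler registration above
    enable_single_replica_ckpt_restoring=False,
)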

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ enable_orbax_v1: False
 checkpoint_conversion_fn: none
 # optional checkpoint context to use for loading. options: "orbax", "safetensors"
 source_checkpoint_layout: "orbax"
+
+# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
+colocated_python_checkpointing: False
 ############################### end checkpointing ##################################

src/maxtext/configs/types.py

Lines changed: 49 additions & 17 deletions
@@ -314,6 +314,10 @@ class Checkpointing(BaseModel):
       True, description="If True, saves a final checkpoint upon training completion."
   )
   enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
+  colocated_python_checkpointing: bool = Field(
+      False,
+      description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
+  )
 
 
 class OrbaxStorage(BaseModel):

@@ -599,7 +603,8 @@ class MoEGeneral(BaseModel):
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(
-      True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
+      True,
+      description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul.",
   )
   use_ring_of_experts: bool = Field(
       False,

@@ -1003,7 +1008,8 @@ class GrainDataset(BaseModel):
   grain_train_files: PathStr = Field("", description="Path to Grain training files.")
   grain_eval_files: PathStr = Field("", description="Path to Grain evaluation files.")
   grain_train_mixture_config_path: PathStr = Field(
-      "", description="Path to a JSON file specifying the mixture weights for Grain training data."
+      "",
+      description="Path to a JSON file specifying the mixture weights for Grain training data.",
   )
   grain_file_type: str = Field("arrayrecord", description="File type for Grain data.")
   grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")

@@ -1049,10 +1055,12 @@ class Distillation(BaseModel):
   # These dictionaries allow flexible configuration injection for Student/Teacher
   # without needing to duplicate the entire MaxText schema here.
   student_overrides: dict[str, Any] = Field(
-      default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})."
+      default_factory=dict,
+      description="Overrides specific to the Student model (e.g., {'num_query_heads': 16}).",
   )
   teacher_overrides: dict[str, Any] = Field(
-      default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})."
+      default_factory=dict,
+      description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
   )
 
   # --- Loss Params ---

@@ -1122,16 +1130,22 @@ class Optimizer(BaseModel):
   )
   learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
   lr_schedule_type: LearningRateScheduleType = Field(
-      LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
+      LearningRateScheduleType.COSINE,
+      description="The type of learning rate schedule to use.",
   )
   learning_rate_final_fraction: float = Field(
-      0.1, description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules)."
+      0.1,
+      description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules).",
   )
   wsd_decay_steps_fraction: float = Field(
-      0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
+      0.1,
+      ge=0.0,
+      le=1.0,
+      description="Fraction of total steps for decay phase in WSD schedule.",
   )
   wsd_decay_style: WsdDecayStyle = Field(
-      WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
+      WsdDecayStyle.LINEAR,
+      description="The decay style for WSD schedule ('linear' or 'cosine').",
   )
   warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
   learning_rate_schedule_steps: int = Field(

@@ -1172,10 +1186,12 @@ class Muon(BaseModel):
 
   muon_beta: float = Field(0.95, description="Decay rate for the exponentially weighted average of grads.")
   muon_weight_decay: float = Field(
-      0, description="Strength of the weight decay regularization. This is multiplied with the learning rate."
+      0,
+      description="Strength of the weight decay regularization. This is multiplied with the learning rate.",
   )
   muon_consistent_rms: None | float = Field(
-      None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2)."
+      None,
+      description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).",
   )
 

@@ -1552,7 +1568,8 @@ class RLHardware(BaseModel):
       "than one model replica in rollout.",
   )
   rollout_tensor_parallelism: int = Field(
-      -1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
+      -1,
+      description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.",
   )
 

@@ -1567,7 +1584,8 @@ class VLLM(BaseModel):
   max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
   vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
   vllm_hf_overrides: dict[str, Any] = Field(
-      default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
+      default_factory=dict,
+      description="Overrides for HuggingFace model config for MaxText model.",
   )
   vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")

@@ -1646,7 +1664,8 @@ class Engram(BaseModel):
   engram_num_heads: int = Field(8, description="Number of heads dedicated to the Engram.")
   engram_head_dim: int = Field(1280, description="Head dimension for heads.")
   engram_vocab_bases: list[int] = Field(
-      default_factory=list, description="List of minimum head vocab sizes for each n-gram order."
+      default_factory=list,
+      description="List of minimum head vocab sizes for each n-gram order.",
   )
   engram_max_ngram_size: int = Field(3, description="The max 'n' in N-gram.")
   engram_kernel_size: int = Field(4, description="Temporal window size for Engram convolution.")

@@ -1892,7 +1911,8 @@ class MaxTextConfig(
 
   debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
   rl: RL = Field(
-      default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO)."
+      default_factory=RL,
+      description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
   )
   model_config = ConfigDict(extra="forbid", protected_namespaces=())
 

@@ -1941,7 +1961,11 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
           filter(
               os.path.exists,
               (
-                  os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", os.path.basename(tokenizer_path)),
+                  os.path.join(
+                      MAXTEXT_ASSETS_ROOT,
+                      "tokenizers",
+                      os.path.basename(tokenizer_path),
+                  ),
                   os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", tokenizer_path),
               ),
           ),

@@ -2093,7 +2117,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         self.global_batch_size_to_eval_on,
         self.micro_batch_size_to_eval_on,
     ) = calculate_global_batch_sizes(
-        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+        self.eval_per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        1,
     )
 
     # Calculate ramp-up batch size parameters if enabled.

@@ -2262,6 +2289,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       raise ValueError("`local_checkpoint_period` must be > 0 for multi-tier checkpointing.")
     if self.multi_tier_checkpointing_backup_interval_minutes <= 0:
       raise ValueError("`multi_tier_checkpointing_backup_interval_minutes` must be > 0.")
+    if self.colocated_python_checkpointing and not self.enable_single_controller:
+      raise ValueError("`colocated_python_checkpointing` is only supported with `enable_single_controller` set to True.")
     if self.enable_emergency_checkpoint:
       if not self.local_checkpoint_directory:
         raise ValueError("`local_checkpoint_directory` must be set for emergency checkpointing.")

@@ -2423,7 +2452,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
     if self.eval_interval > 0 and not self.grain_eval_files:
       raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
-    if self.tokenizer_type not in (TokenizerType.SENTENCEPIECE, TokenizerType.HUGGINGFACE):
+    if self.tokenizer_type not in (
+        TokenizerType.SENTENCEPIECE,
+        TokenizerType.HUGGINGFACE,
+    ):
       raise ValueError(
           f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
       )
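The practical consequence of the new check in set_derived_and_validate_values: enabling colocated_python_checkpointing without enable_single_controller now fails fast at config validation time. A standalone illustration of the guard, using a hypothetical stripped-down class rather than the full pydantic MaxTextConfig:

class _CheckpointFlagsDemo:
  # Hypothetical stand-in holding only the two flags involved in the new check.
  def __init__(self, colocated_python_checkpointing=False, enable_single_controller=False):
    self.colocated_python_checkpointing = colocated_python_checkpointing
    self.enable_single_controller = enable_single_controller
    if self.colocated_python_checkpointing and not self.enable_single_controller:
      raise ValueError(
          "`colocated_python_checkpointing` is only supported with `enable_single_controller` set to True."
      )


_CheckpointFlagsDemo(colocated_python_checkpointing=True, enable_single_controller=True)  # passes validation
try:
  _CheckpointFlagsDemo(colocated_python_checkpointing=True)  # enable_single_controller defaults to False
except ValueError as e:
  print(e)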

src/maxtext/utils/train_utils.py

Lines changed: 4 additions & 1 deletion
@@ -61,7 +61,7 @@ def create_training_tools(config, model, mesh):
   # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
   use_ocdbt = config.checkpoint_storage_use_ocdbt
   use_zarr3 = config.checkpoint_storage_use_zarr3
-  if config.enable_single_controller:
+  if config.enable_single_controller and not config.colocated_python_checkpointing:
     use_ocdbt, use_zarr3 = False, False
 
   checkpoint_dir = ""

@@ -79,6 +79,9 @@ def create_training_tools(config, model, mesh):
       config.enable_continuous_checkpointing,
       config.max_num_checkpoints_to_keep,
       config.checkpoint_storage_concurrent_gb,
+      config.enable_single_controller,
+      config.colocated_python_checkpointing,
+      config.enable_single_replica_ckpt_restoring,
   )
 
   return init_rng, checkpoint_manager, learning_rate_schedule, tx
