@@ -314,6 +314,10 @@ class Checkpointing(BaseModel):
         True, description="If True, saves a final checkpoint upon training completion."
     )
     enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
+    colocated_python_checkpointing: bool = Field(
+        False,
+        description="If True, enables checkpointing from remote TPU VMs instead of the head node on Pathways.",
+    )
 
 
 class OrbaxStorage(BaseModel):
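A minimal sketch of enabling the new flag, assuming pydantic v2 and only the two fields from this hunk (the real Checkpointing model has many more fields, and a later hunk in this diff additionally requires enable_single_controller to be True):

from pydantic import BaseModel, Field

class Checkpointing(BaseModel):
    # Reduced stand-in for the model above, for illustration only.
    enable_continuous_checkpointing: bool = Field(False)
    colocated_python_checkpointing: bool = Field(False)

ckpt = Checkpointing(colocated_python_checkpointing=True)
assert ckpt.colocated_python_checkpointing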
@@ -599,7 +603,8 @@ class MoEGeneral(BaseModel):
     capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
     load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
     use_custom_sort_vjp: bool = Field(
-        True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
+        True,
+        description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul.",
     )
     use_ring_of_experts: bool = Field(
         False,
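use_custom_sort_vjp toggles a custom-gradient sort; the kernel itself is not in this diff. As a self-contained illustration of the technique (not MaxText's implementation), a sort with a hand-written VJP in JAX looks like:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def sort_with_custom_vjp(x):
    return jnp.sort(x)

def _sort_fwd(x):
    order = jnp.argsort(x)
    return x[order], order  # save the permutation as the residual

def _sort_bwd(order, g):
    # Route each output cotangent back to the input position it came from.
    return (g[jnp.argsort(order)],)

sort_with_custom_vjp.defvjp(_sort_fwd, _sort_bwd)

# Sanity check against the scatter semantics of sorting:
x = jnp.array([3.0, 1.0, 2.0])
w = jnp.array([1.0, 10.0, 100.0])
print(jax.grad(lambda v: (sort_with_custom_vjp(v) * w).sum())(x))  # [100., 1., 10.]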
@@ -1003,7 +1008,8 @@ class GrainDataset(BaseModel):
     grain_train_files: PathStr = Field("", description="Path to Grain training files.")
     grain_eval_files: PathStr = Field("", description="Path to Grain evaluation files.")
     grain_train_mixture_config_path: PathStr = Field(
-        "", description="Path to a JSON file specifying the mixture weights for Grain training data."
+        "",
+        description="Path to a JSON file specifying the mixture weights for Grain training data.",
     )
     grain_file_type: str = Field("arrayrecord", description="File type for Grain data.")
     grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
@@ -1049,10 +1055,12 @@ class Distillation(BaseModel):
     # These dictionaries allow flexible configuration injection for Student/Teacher
     # without needing to duplicate the entire MaxText schema here.
     student_overrides: dict[str, Any] = Field(
-        default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})."
+        default_factory=dict,
+        description="Overrides specific to the Student model (e.g., {'num_query_heads': 16}).",
     )
     teacher_overrides: dict[str, Any] = Field(
-        default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})."
+        default_factory=dict,
+        description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
     )
 
     # --- Loss Params ---
@@ -1122,16 +1130,22 @@ class Optimizer(BaseModel):
     )
     learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
     lr_schedule_type: LearningRateScheduleType = Field(
-        LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
+        LearningRateScheduleType.COSINE,
+        description="The type of learning rate schedule to use.",
     )
     learning_rate_final_fraction: float = Field(
-        0.1, description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules)."
+        0.1,
+        description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules).",
     )
     wsd_decay_steps_fraction: float = Field(
-        0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
+        0.1,
+        ge=0.0,
+        le=1.0,
+        description="Fraction of total steps for decay phase in WSD schedule.",
     )
     wsd_decay_style: WsdDecayStyle = Field(
-        WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
+        WsdDecayStyle.LINEAR,
+        description="The decay style for WSD schedule ('linear' or 'cosine').",
     )
     warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
     learning_rate_schedule_steps: int = Field(
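For context on how the warmup and decay fractions above compose: a hedged sketch of a warmup-stable-decay (WSD) schedule with the linear decay style. wsd_learning_rate is a hypothetical helper, not the schedule builder MaxText actually uses:

def wsd_learning_rate(step, total_steps, peak_lr, warmup_frac, decay_frac, final_frac):
    # Warmup-Stable-Decay: linear warmup, flat stable phase, linear decay
    # down to final_frac * peak_lr (cf. learning_rate_final_fraction).
    warmup_steps = int(total_steps * warmup_frac)
    decay_steps = int(total_steps * decay_frac)
    stable_end = total_steps - decay_steps
    if step < warmup_steps:
        return peak_lr * step / max(warmup_steps, 1)
    if step < stable_end:
        return peak_lr
    frac = (step - stable_end) / max(decay_steps, 1)
    return peak_lr * (1.0 - frac * (1.0 - final_frac))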
@@ -1172,10 +1186,12 @@ class Muon(BaseModel):
 
     muon_beta: float = Field(0.95, description="Decay rate for the exponentially weighted average of grads.")
     muon_weight_decay: float = Field(
-        0, description="Strength of the weight decay regularization. This is multiplied with the learning rate."
+        0,
+        description="Strength of the weight decay regularization. This is multiplied with the learning rate.",
     )
     muon_consistent_rms: None | float = Field(
-        None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2)."
+        None,
+        description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).",
     )
 
 
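A hedged sketch of what the muon_consistent_rms float would do to an update, assuming the straightforward reading of the description (rescale each update to a fixed root-mean-square); MaxText's actual Muon implementation is not part of this diff:

import jax.numpy as jnp

def apply_consistent_rms(update, target_rms=0.2):
    # Rescale the update so its RMS equals target_rms, the float value
    # muon_consistent_rms would hold (0.2 is the recommended setting above).
    rms = jnp.sqrt(jnp.mean(jnp.square(update)))
    return update * (target_rms / (rms + 1e-12))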
@@ -1552,7 +1568,8 @@ class RLHardware(BaseModel):
15521568 "than one model replica in rollout." ,
15531569 )
15541570 rollout_tensor_parallelism : int = Field (
1555- - 1 , description = "Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
1571+ - 1 ,
1572+ description = "Tensor parallelism per replica for rollout. If not specified, it will be auto-determined." ,
15561573 )
15571574
15581575
@@ -1567,7 +1584,8 @@ class VLLM(BaseModel):
     max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
     vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
     vllm_hf_overrides: dict[str, Any] = Field(
-        default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
+        default_factory=dict,
+        description="Overrides for HuggingFace model config for MaxText model.",
     )
     vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
 
@@ -1646,7 +1664,8 @@ class Engram(BaseModel):
     engram_num_heads: int = Field(8, description="Number of heads dedicated to the Engram.")
     engram_head_dim: int = Field(1280, description="Head dimension for heads.")
     engram_vocab_bases: list[int] = Field(
-        default_factory=list, description="List of minimum head vocab sizes for each n-gram order."
+        default_factory=list,
+        description="List of minimum head vocab sizes for each n-gram order.",
     )
     engram_max_ngram_size: int = Field(3, description="The max 'n' in N-gram.")
     engram_kernel_size: int = Field(4, description="Temporal window size for Engram convolution.")
@@ -1892,7 +1911,8 @@ class MaxTextConfig(
 
     debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
     rl: RL = Field(
-        default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO)."
+        default_factory=RL,
+        description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
     )
     model_config = ConfigDict(extra="forbid", protected_namespaces=())
 
@@ -1941,7 +1961,11 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
             filter(
                 os.path.exists,
                 (
-                    os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", os.path.basename(tokenizer_path)),
+                    os.path.join(
+                        MAXTEXT_ASSETS_ROOT,
+                        "tokenizers",
+                        os.path.basename(tokenizer_path),
+                    ),
                     os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", tokenizer_path),
                 ),
             ),
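The expression above tries the basename under the assets root first, then the path as given, keeping the first candidate that exists. Equivalently, as a standalone sketch (the helper name and the fallback default are assumptions; this hunk does not show what happens when neither candidate exists):

import os

def resolve_tokenizer_path(assets_root: str, tokenizer_path: str) -> str:
    # Hypothetical helper mirroring the filter(...) expression above.
    candidates = (
        os.path.join(assets_root, "tokenizers", os.path.basename(tokenizer_path)),
        os.path.join(assets_root, "tokenizers", tokenizer_path),
    )
    # Assumed fallback: return the raw path if no candidate exists on disk.
    return next(filter(os.path.exists, candidates), tokenizer_path)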
@@ -2093,7 +2117,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         self.global_batch_size_to_eval_on,
         self.micro_batch_size_to_eval_on,
     ) = calculate_global_batch_sizes(
-        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+        self.eval_per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        1,
     )
 
     # Calculate ramp-up batch size parameters if enabled.
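The trailing 1 in the call above presumably pins the final factor to a single step for eval. A plausible reconstruction of the helper's arithmetic, matching the call signature and the two-value unpacking above but not necessarily the file's actual body:

def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_devices, final_factor):
    # Hypothetical reconstruction: the micro batch covers all devices for one
    # step; the global batch folds in the extra factor (1 in the eval call).
    micro_batch_size = int(per_device_batch_size * num_devices * expansion_factor)
    global_batch_size = micro_batch_size * final_factor
    return global_batch_size, micro_batch_size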
@@ -2262,6 +2289,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
             raise ValueError("`local_checkpoint_period` must be > 0 for multi-tier checkpointing.")
         if self.multi_tier_checkpointing_backup_interval_minutes <= 0:
             raise ValueError("`multi_tier_checkpointing_backup_interval_minutes` must be > 0.")
+        if self.colocated_python_checkpointing and not self.enable_single_controller:
+            raise ValueError("`colocated_python_checkpointing` is only supported with `enable_single_controller` set to True.")
         if self.enable_emergency_checkpoint:
             if not self.local_checkpoint_directory:
                 raise ValueError("`local_checkpoint_directory` must be set for emergency checkpointing.")
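The new guard couples the two flags at validation time. In isolation, the same pattern as a runnable pydantic v2 sketch (a reduced stand-in model; the real check lives in MaxTextConfig.set_derived_and_validate_values):

from pydantic import BaseModel, model_validator

class CkptFlags(BaseModel):
    # Stand-in carrying only the two fields the guard relates.
    enable_single_controller: bool = False
    colocated_python_checkpointing: bool = False

    @model_validator(mode="after")
    def check_colocated(self) -> "CkptFlags":
        if self.colocated_python_checkpointing and not self.enable_single_controller:
            raise ValueError(
                "`colocated_python_checkpointing` is only supported with "
                "`enable_single_controller` set to True."
            )
        return self

CkptFlags(colocated_python_checkpointing=True, enable_single_controller=True)  # ok
CkptFlags(colocated_python_checkpointing=True)  # raises ValueError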
@@ -2423,7 +2452,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
             raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
         if self.eval_interval > 0 and not self.grain_eval_files:
             raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
-        if self.tokenizer_type not in (TokenizerType.SENTENCEPIECE, TokenizerType.HUGGINGFACE):
+        if self.tokenizer_type not in (
+            TokenizerType.SENTENCEPIECE,
+            TokenizerType.HUGGINGFACE,
+        ):
             raise ValueError(
                 f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
             )