@@ -784,6 +784,7 @@ class LayoutAndSharding(BaseModel):
784784class DcnParallelism (BaseModel ):
785785 """Parallelism dimensions across the DCN (Data Center Network)."""
786786
787+ dcn_diloco_parallelism : int = Field (1 , description = "DCN axis for Diloco parallelism." )
787788 dcn_data_parallelism : int = Field (- 1 , description = "DCN axis for data parallelism." )
788789 dcn_fsdp_parallelism : int = Field (1 , description = "DCN axis for FSDP." )
789790 dcn_fsdp_transpose_parallelism : int = Field (1 , description = "DCN axis for FSDP transpose." )
@@ -803,6 +804,7 @@ class DcnParallelism(BaseModel):
803804class IciParallelism (BaseModel ):
804805 """Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""
805806
807+ ici_diloco_parallelism : int = Field (1 , description = "ICI axis for Diloco parallelism." )
806808 ici_data_parallelism : int = Field (1 , description = "ICI axis for data parallelism." )
807809 ici_fsdp_parallelism : int = Field (- 1 , description = "ICI axis for FSDP." )
808810 ici_fsdp_transpose_parallelism : int = Field (1 , description = "ICI axis for FSDP transpose." )
@@ -1082,6 +1084,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
10821084 sinkhorn_iterations : PositiveInt = Field (20 , description = "The number of iterations for the Sinkhorn-Knopp algorithm." )
10831085
10841086
1087+ class DilocoParams (BaseModel ):
1088+ """Diloco Hyperparameters"""
1089+
1090+ enable_diloco : bool = Field (False , description = "Enable Diloco parallelism" )
1091+ diloco_sync_period : int = Field (36 , description = "Diloco sync period." )
1092+ diloco_outer_lr : float = Field (0.3 , description = "learning rate for outer optimizer." )
1093+ diloco_outer_momentum : float = Field (0.9 , description = "momentum for outer optimizer." )
1094+
1095+
10851096class Optimizer (BaseModel ):
10861097 """Configuration for the optimizer and learning rate schedule."""
10871098
@@ -1632,6 +1643,11 @@ class DerivedValues(BaseModel):
16321643 description = "Effective number of query heads, scaled by `global_parameter_scale`." ,
16331644 )
16341645
1646+ num_diloco_replicas : None | int = Field (
1647+ None ,
1648+ description = "The number of diloco replicas, derived from ICI and DCN values." ,
1649+ )
1650+
16351651 ici_parallelism : None | list [int ] = Field (
16361652 None ,
16371653 description = "Aggregated list of all ICI parallelism values for legacy compatibility." ,
@@ -1779,6 +1795,7 @@ class MaxTextConfig(
17791795 RematAndOffload ,
17801796 TrainingLoop ,
17811797 ManifoldConstrainedHyperConnections ,
1798+ DilocoParams ,
17821799 Optimizer ,
17831800 AdamW ,
17841801 Muon ,
@@ -2375,6 +2392,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23752392 # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
23762393 if self .using_pipeline_parallelism and self .mesh_axes and self .mesh_axes [0 ] == "stage" :
23772394 self .ici_parallelism = [
2395+ self .ici_diloco_parallelism ,
23782396 self .ici_pipeline_parallelism ,
23792397 self .ici_data_parallelism ,
23802398 self .ici_fsdp_parallelism ,
@@ -2389,6 +2407,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23892407 self .ici_autoregressive_parallelism ,
23902408 ]
23912409 self .dcn_parallelism = [
2410+ self .dcn_diloco_parallelism ,
23922411 self .dcn_pipeline_parallelism ,
23932412 self .dcn_data_parallelism ,
23942413 self .dcn_fsdp_parallelism ,
@@ -2404,6 +2423,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24042423 ]
24052424 else :
24062425 ici_map = {
2426+ "diloco" : self .ici_diloco_parallelism ,
24072427 "data" : self .ici_data_parallelism ,
24082428 "stage" : self .ici_pipeline_parallelism ,
24092429 "fsdp" : self .ici_fsdp_parallelism ,
@@ -2422,6 +2442,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24222442 self .ici_parallelism = [ici_map [axis ] for axis in self .mesh_axes ]
24232443
24242444 dcn_map = {
2445+ "diloco" : self .dcn_diloco_parallelism ,
24252446 "data" : self .dcn_data_parallelism ,
24262447 "stage" : self .dcn_pipeline_parallelism ,
24272448 "fsdp" : self .dcn_fsdp_parallelism ,
@@ -2439,6 +2460,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24392460 }
24402461 self .dcn_parallelism = [dcn_map [axis ] for axis in self .mesh_axes ]
24412462
2463+ # Diloco params
2464+ self .num_diloco_replicas = int (self .ici_diloco_parallelism * self .dcn_diloco_parallelism )
2465+
24422466 # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
24432467 if isinstance (self .decoder_block , str ):
24442468 self .decoder_block = DecoderBlockType (self .decoder_block .lower ())
0 commit comments