Skip to content

Commit 1683fb7

Browse files
committed
diloco trainer
1 parent 570ee04 commit 1683fb7

33 files changed

Lines changed: 3782 additions & 20 deletions

File tree

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ array-record
44
cloud-accelerator-diagnostics
55
cloud-tpu-diagnostics
66
datasets
7+
drjax
78
flax
89
gcsfs
910
google-api-python-client

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dill>=0.4.0
4040
distlib>=0.4.0
4141
dm-tree>=0.1.9
4242
docstring-parser>=0.17.0
43+
drjax>=0.1.4
4344
editdistance>=0.8.1
4445
einops>=0.8.1
4546
einshape>=1.0

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ dill>=0.4.0
4141
distlib>=0.4.0
4242
dm-tree>=0.1.9
4343
docstring-parser>=0.17.0
44+
drjax>=0.1.4
4445
editdistance>=0.8.1
4546
einops>=0.8.1
4647
einshape>=1.0

dependencies/requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ array-record
44
cloud-accelerator-diagnostics
55
cloud-tpu-diagnostics
66
datasets
7+
drjax>=0.1.4
78
flax
89
gcsfs
910
google-api-python-client
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
{
2+
"architectures": [
3+
"MaxTextForCausalLM"
4+
],
5+
"attention_bias": false,
6+
"attention_dropout": 0.0,
7+
"auto_map": {
8+
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
9+
"AutoModel": "modeling_deepseek.DeepseekV3Model",
10+
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
11+
},
12+
"bos_token_id": 0,
13+
"eos_token_id": 1,
14+
"ep_size": 1,
15+
"first_k_dense_replace": 3,
16+
"hidden_act": "silu",
17+
"hidden_size": 7168,
18+
"initializer_range": 0.02,
19+
"intermediate_size": 18432,
20+
"kv_lora_rank": 512,
21+
"max_position_embeddings": 163840,
22+
"model_type": "deepseek_v3",
23+
"moe_intermediate_size": 2048,
24+
"moe_layer_freq": 1,
25+
"n_group": 8,
26+
"n_routed_experts": 256,
27+
"n_shared_experts": 1,
28+
"norm_topk_prob": true,
29+
"num_attention_heads": 128,
30+
"num_experts_per_tok": 8,
31+
"num_hidden_layers": 61,
32+
"num_key_value_heads": 128,
33+
"num_nextn_predict_layers": 1,
34+
"q_lora_rank": 1536,
35+
"qk_nope_head_dim": 128,
36+
"qk_rope_head_dim": 64,
37+
"rms_norm_eps": 1e-06,
38+
"rope_scaling": {
39+
"beta_fast": 32,
40+
"beta_slow": 1,
41+
"factor": 40,
42+
"mscale": 1.0,
43+
"mscale_all_dim": 1.0,
44+
"original_max_position_embeddings": 4096,
45+
"type": "yarn"
46+
},
47+
"rope_theta": 10000,
48+
"routed_scaling_factor": 2.5,
49+
"scoring_func": "sigmoid",
50+
"tie_word_embeddings": false,
51+
"topk_group": 4,
52+
"topk_method": "noaux_tc",
53+
"torch_dtype": "bfloat16",
54+
"transformers_version": "4.33.1",
55+
"use_cache": true,
56+
"v_head_dim": 128,
57+
"vocab_size": 129280
58+
}

src/MaxText/sharding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@
3636

3737
def get_input_data_sharding(config, mesh):
3838
"""Get the input data sharding for the model"""
39-
return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
39+
if config.enable_diloco:
40+
data_sharding = create_sharding(
41+
mesh, ["diloco"] + config.input_data_sharding_logical_axes, rules=config.logical_axis_rules
42+
)
43+
else:
44+
data_sharding = create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
45+
return data_sharding
4046

4147

4248
def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):

src/MaxText/train_compile.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Sequence
2525
import os
2626
import pickle
27+
import functools
2728

2829
from absl import app
2930

@@ -45,6 +46,7 @@
4546
from maxtext.utils import gcs_utils
4647
from maxtext.utils import max_utils
4748
from maxtext.utils import maxtext_utils
49+
from maxtext.trainers.diloco import diloco
4850

4951
# pylint: disable=too-many-positional-arguments
5052

@@ -235,13 +237,32 @@ def main(argv: Sequence[str]) -> None:
235237

236238
# Get data sharding
237239
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
238-
239-
# Get function to compile and shardings
240-
func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
241-
maxtext_utils.get_functional_train_with_signature(
242-
train.train_step, data_sharding, state_mesh_shardings, model, config
243-
)
244-
)
240+
if config.enable_diloco:
241+
# Build abstract DiLoCo state and shardings for AOT compilation
242+
abstract_state = shaped_train_args[0]
243+
diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
244+
config, abstract_state, state_mesh_shardings, topology_mesh
245+
)
246+
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2])
247+
248+
# Wrap train_step with diloco
249+
train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None)
250+
train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)
251+
252+
# For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
253+
func_to_compile = train_step_fn
254+
func_to_compile.__name__ = "train_step"
255+
in_shard = (state_mesh_shardings, data_sharding, None) # State, batch, rng
256+
out_shard = (state_mesh_shardings, None) # State, metrics
257+
static_argnums = ()
258+
donate_argnums = 0
259+
else:
260+
# Get function to compile and shardings
261+
func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
262+
maxtext_utils.get_functional_train_with_signature(
263+
train.train_step, data_sharding, state_mesh_shardings, model, config
264+
)
265+
)
245266

246267
# print weights sharding info under debug sharding mode
247268
if config.debug_sharding:

src/maxtext/common/data_loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
maybe_record_goodput,
2626
)
2727
from maxtext.utils import exceptions
28+
from maxtext.trainers.diloco import diloco
2829

2930

3031
class DataLoader:
@@ -70,10 +71,13 @@ def load_next_batch_pre_sharding(self):
7071

7172
def load_next_batch(self, *args, **kwargs):
7273
"""Loads the next batch with sharding hint"""
73-
return jax.device_put(
74+
example_batch = jax.device_put(
7475
self.load_next_batch_pre_sharding(),
7576
self.input_data_shardings,
7677
)
78+
if self.config.enable_diloco:
79+
example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch)
80+
return example_batch
7781

7882
def check_example_batch(self):
7983
if self.config.max_checkify:

src/maxtext/configs/base.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
400400

401401
# Parallelism
402402
shard_mode: "auto" # can be either auto or explicit
403-
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
403+
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
404404
logical_axis_rules: [
405405
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
406406
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -483,6 +483,7 @@ logical_axis_rules: [
483483
['paged_kv_head_dim_size', []],
484484
['dense_layers', []],
485485
['moe_layers', []],
486+
['diloco', 'diloco'],
486487
]
487488
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
488489
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -495,6 +496,7 @@ sharding_tolerance: 0.02
495496
# value to auto-shard based on available slices and devices.
496497
# By default, product of the DCN axes should equal number of slices
497498
# and product of the ICI axes should equal number of devices per slice.
499+
dcn_diloco_parallelism: 1
498500
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
499501
dcn_fsdp_parallelism: 1
500502
dcn_fsdp_transpose_parallelism: 1
@@ -507,6 +509,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
507509
dcn_pipeline_parallelism: 1
508510
dcn_expert_parallelism: 1
509511
dcn_autoregressive_parallelism: 1 # never recommended
512+
ici_diloco_parallelism: 1
510513
ici_data_parallelism: 1
511514
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
512515
ici_fsdp_transpose_parallelism: 1
@@ -738,6 +741,12 @@ enable_data_shuffling: True
738741
data_shuffle_seed: 0
739742
init_weights_seed: 0
740743

744+
# DiLoCo params: enable flag, sync period, and the outer optimizer's learning rate and momentum.
745+
enable_diloco: False
746+
diloco_sync_period: 36
747+
diloco_outer_lr: 0.3
748+
diloco_outer_momentum: 0.9
749+
741750
# You may disable clipping by setting gradient_clipping_threshold to zero.
742751
gradient_clipping_threshold: 1.0
743752

src/maxtext/configs/types.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,7 @@ class LayoutAndSharding(BaseModel):
784784
class DcnParallelism(BaseModel):
785785
"""Parallelism dimensions across the DCN (Data Center Network)."""
786786

787+
dcn_diloco_parallelism: int = Field(1, description="DCN axis for Diloco parallelism.")
787788
dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.")
788789
dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.")
789790
dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.")
@@ -803,6 +804,7 @@ class DcnParallelism(BaseModel):
803804
class IciParallelism(BaseModel):
804805
"""Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""
805806

807+
ici_diloco_parallelism: int = Field(1, description="ICI axis for Diloco parallelism.")
806808
ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.")
807809
ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.")
808810
ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.")
@@ -1082,6 +1084,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
10821084
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
10831085

10841086

1087+
class DilocoParams(BaseModel):
1088+
"""Diloco Hyperparameters"""
1089+
1090+
enable_diloco: bool = Field(False, description="Enable Diloco parallelism")
1091+
diloco_sync_period: int = Field(36, description="Diloco sync period.")
1092+
diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.")
1093+
diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.")
1094+
1095+
10851096
class Optimizer(BaseModel):
10861097
"""Configuration for the optimizer and learning rate schedule."""
10871098

@@ -1632,6 +1643,11 @@ class DerivedValues(BaseModel):
16321643
description="Effective number of query heads, scaled by `global_parameter_scale`.",
16331644
)
16341645

1646+
num_diloco_replicas: None | int = Field(
1647+
None,
1648+
description="The number of diloco replicas, derived from ICI and DCN values.",
1649+
)
1650+
16351651
ici_parallelism: None | list[int] = Field(
16361652
None,
16371653
description="Aggregated list of all ICI parallelism values for legacy compatibility.",
@@ -1779,6 +1795,7 @@ class MaxTextConfig(
17791795
RematAndOffload,
17801796
TrainingLoop,
17811797
ManifoldConstrainedHyperConnections,
1798+
DilocoParams,
17821799
Optimizer,
17831800
AdamW,
17841801
Muon,
@@ -2375,6 +2392,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23752392
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
23762393
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
23772394
self.ici_parallelism = [
2395+
self.ici_diloco_parallelism,
23782396
self.ici_pipeline_parallelism,
23792397
self.ici_data_parallelism,
23802398
self.ici_fsdp_parallelism,
@@ -2389,6 +2407,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23892407
self.ici_autoregressive_parallelism,
23902408
]
23912409
self.dcn_parallelism = [
2410+
self.dcn_diloco_parallelism,
23922411
self.dcn_pipeline_parallelism,
23932412
self.dcn_data_parallelism,
23942413
self.dcn_fsdp_parallelism,
@@ -2404,6 +2423,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24042423
]
24052424
else:
24062425
ici_map = {
2426+
"diloco": self.ici_diloco_parallelism,
24072427
"data": self.ici_data_parallelism,
24082428
"stage": self.ici_pipeline_parallelism,
24092429
"fsdp": self.ici_fsdp_parallelism,
@@ -2422,6 +2442,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24222442
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
24232443

24242444
dcn_map = {
2445+
"diloco": self.dcn_diloco_parallelism,
24252446
"data": self.dcn_data_parallelism,
24262447
"stage": self.dcn_pipeline_parallelism,
24272448
"fsdp": self.dcn_fsdp_parallelism,
@@ -2439,6 +2460,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24392460
}
24402461
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
24412462

2463+
# DiLoCo params: the replica count is the product of the ICI and DCN diloco parallelism values.
2464+
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
2465+
24422466
# Final string-to-enum conversions if they haven't been coerced by pydantic yet.
24432467
if isinstance(self.decoder_block, str):
24442468
self.decoder_block = DecoderBlockType(self.decoder_block.lower())

0 commit comments

Comments
 (0)