Commit 79eecc9
Introduce a DSv3 config with 2 logical fsdp sharding axes.
PiperOrigin-RevId: 843675516
1 parent 69adf5d commit 79eecc9

5 files changed

Lines changed: 197 additions & 94 deletions
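For orientation: the headline change is a new DeepSeek V3 config that shards weights across two logical FSDP mesh axes (`fsdp` and `fsdp_transpose`), so per-device parameter memory shrinks with the product of the two axis sizes rather than a single axis. A rough sketch of the arithmetic, with illustrative mesh sizes (the 671B count comes from the model name; the axis sizes below are assumptions, not values from this commit):

```python
# Back-of-envelope memory math for 2D FSDP sharding; the axis sizes are
# illustrative assumptions, not values taken from this commit.
params = 671e9        # DeepSeek V3 - 671B
bytes_per_param = 2   # bfloat16 weights

fsdp, fsdp_transpose = 64, 4  # hypothetical mesh axis sizes
per_device_gib = params * bytes_per_param / (fsdp * fsdp_transpose) / 2**30
print(f"~{per_device_gib:.1f} GiB of parameters per device")  # ~4.9 GiB
```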

src/MaxText/configs/base.yml

Lines changed: 67 additions & 64 deletions
@@ -18,6 +18,7 @@ run_name: ""
 
 model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
 override_model_config: False # When set to true, allows overriding model parameters via CLI for the purpose of debugging/testing.
+override_logical_axis_rules: False # When set, overrides logical axis rules instead of merging them.
 debug:
   rl: False # RL-specific debugging
 
@@ -70,72 +71,72 @@ checkpoint_storage_concurrent_gb: 96
 
 # Bool flag for enabling Orbax v1.
 enable_orbax_v1: False
-# function for processing loaded checkpoint dict into a format maxtext can understand. (for other formats, i.e. safetensors)
-checkpoint_conversion_fn: none
-# optional checkpoint context to use for loading. options: "orbax", "safetensors"
+# Function for processing loaded checkpoint dict into a format MaxText can understand. (for other formats, i.e. SafeTensors)
+checkpoint_conversion_fn: None
+# Optional checkpoint context to use for loading. Options: "orbax", "safetensors"
 source_checkpoint_layout: "orbax"
-############################### end checkpointing ##################################
+############################### END CHECKPOINTING ##################################
 
 
-reuse_example_batch: 0 # for testing tpu performance, this options repeated uses the same batch.
+reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.
 
 
-metrics_file: "" # for testing, local file that stores scalar metrics. if empty, no metrics are written.
-# if true save metrics such as loss and tflops to gcs in {base_output_directory}/{run_name}/metrics/
-gcs_metrics: false
+metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
+# If true, save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
+gcs_metrics: False
 
-# if true save config to gcs in {base_output_directory}/{run_name}/
-save_config_to_gcs: false
+# If true, save config to GCS in {base_output_directory}/{run_name}/
+save_config_to_gcs: False
 
-# gradient dtype
+# Gradient dtype
 grad_dtype: "float32"
 
-# activation dtypes.
+# Activation dtypes.
 dtype: "bfloat16"
-# used to configure quantization in the transformer layers, defaults to null implying bf16.
-# possible alternative settings are as follows:
+# Used to configure quantization in the transformer layers, defaults to null implying bf16.
+# Possible alternative settings are as follows:
 # 'int8' for dynamic range quantization using 8-bits
-# 'intmp' for mixed precision quantization for inference as described here: src/MaxText/configs/quantization/readme.md
-# 'fp8' for 8-bit floating-point gemms on nvidia gpus.
-# 'nanoo_fp8' for 8-bit floating-point gemms on amd mi300/mi325 gpus.
-# 'fp8_full' for fp8 quantization with static scaling.
+# 'intmp' for mixed precision quantization for inference as described here: src/MaxText/configs/quantization/README.md
+# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
+# 'nanoo_fp8' for 8-bit floating-point GeMMs on AMD MI300/MI325 GPUs.
+# 'fp8_full' for FP8 quantization with static scaling.
 quantization: ""
-# used to configure constant_bound_config in aqt lib for static scaling, e.g. constant_bound_config='0.5, 0.5, 0.5, 0.5, 0.5, 0.5'
+# Used to configure constant_bound_config in aqt lib for static scaling, e.g. constant_bound_config='0.5, 0.5, 0.5, 0.5, 0.5, 0.5'
 constant_bound_config: ""
-# choose one of default, high, and highest.
-# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.precision
+# Choose one of default, high, and highest.
+# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 matmul_precision: "default"
-activations_in_float32: false # sets activations to float32 before nonlinearity it true, else dtype
-# used to replicate the quantization scale to avoid the inefficient xla fusion for 2d sharding.
-replicate_quant_scale: false
-# path to file with quantization config for intmp.
+activations_in_float32: False # Sets activations to float32 before nonlinearity if true, else dtype
+# Used to replicate the quantization scale to avoid the inefficient XLA fusion for 2d sharding.
+replicate_quant_scale: False
+# Path to file with quantization config for intmp.
 quant_cfg_path: ""
-quantize_kvcache: false # set to true to quantize kv cache values, defaults to false
-# valid kv_quant_axis values:
-# - "" is valid only when quantize_kvcache is false
+quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False
+# Valid kv_quant_axis values:
+# - "" is valid only when quantize_kvcache is False
 # - "dkv" indicates quantize kv cache over the cache_kv, i.e. kv dimension axis
 # - "heads_and_dkv" indicates quantize kv cache over cache_heads and cache_kv axes
-# default to "heads_and_dkv" for faster compution, kv_quant_axis is not used when quantize_kvcache is false
+# Default to "heads_and_dkv" for faster computation, kv_quant_axis is not used when quantize_kvcache is False
 # - "dkv" is expected with better accuracy but degraded computation
 kv_quant_axis: "heads_and_dkv"
 kv_quant_dtype: "int8"
-checkpoint_is_quantized: false # set to true if reading from a saved aqt quantized checkpoint
-# saves params quantized on fly at following path
+checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
+# Saves params quantized on the fly at the following path
 save_quantized_params_path: ""
-#used to configure the mode in which model is called
+# Used to configure the mode in which the model is called
 # when left as is, corresponds to training
 # accepted values are "inference"
 model_call_mode: ""
-use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
-# quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
+use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
+# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 weight_quantization_calibration_method: "absmax"
 act_quantization_calibration_method: "absmax"
 bwd_quantization_calibration_method: "absmax"
-# shard the range finding operation for quantization. by default this is set to number of slices.
+# Shard the range finding operation for quantization. By default this is set to the number of slices.
 quantization_local_shard_count: -1
 
-decoder_block: "llama2" # which style of decoderblock to use.
-# global parameter scale needs to be a power of 2. if you want finer grained control of the model sizes
+decoder_block: "llama2" # which style of DecoderBlock to use.
+# Global parameter scale needs to be a power of 2. If you want finer-grained control of the model sizes,
 # then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads,
 # base_mlp_dim, base_num_decoder_layers and/or head_dim.
 weight_dtype: "float32"
@@ -149,39 +150,39 @@ head_dim: 128
 mlp_activations: ["silu", "linear"]
 mlp_activations_limit: -1.0
 dropout_rate: 0.0
-logits_via_embedding: false
-normalize_embedding_logits: true # whether to normalize pre-softmax logits if logits_via_embedding is true
-logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embedding dot product for stability
-cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
-float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
-float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
-float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
-
-# multi-token prediction configs
-# the number of auxiliary prediction layers to use for mtp.
-# set to 0 to disable the feature.
+logits_via_embedding: False
+normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
+logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
+cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
+float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
+float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
+
+# Multi-Token Prediction Configs
+# The number of auxiliary prediction layers to use for MTP.
+# Set to 0 to disable the feature.
 mtp_num_layers: 0
-# the scaling factor (lambda) for the mtp auxiliary loss. the final loss is:
+# The scaling factor (lambda) for the MTP auxiliary loss. The final loss is:
 # main_loss + mtp_loss_scaling_factor * avg_mtp_loss
 mtp_loss_scaling_factor: 0.1
-# specifies which mtp layer (1-indexed) is used to calculate metrics like the
-# acceptance rate during evaluation. for example, a value of `1` targets the
-# first auxiliary prediction head. set to 0 to disable this specific evaluation
+# Specifies which MTP layer (1-indexed) is used to calculate metrics like the
+# acceptance rate during evaluation. For example, a value of `1` targets the
+# first auxiliary prediction head. Set to 0 to disable this specific evaluation.
 mtp_eval_target_module: 0
 
 # mixture of experts (moe)
 num_experts: 1
 num_experts_per_tok: 1
-megablox: true
-sparse_matmul: true
+megablox: True
+sparse_matmul: True
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.01 # weight for the load balance loss
-use_random_routing: false # whether to use random routing for debug/test purpose
-use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
-use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
-# tunable tiling dimensions used for mlp gmm
-# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
-# tokamax ragged dot - supports all 18 configs
+use_random_routing: False # whether to use random routing for debug/test purposes
+use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
+use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
+# Tunable tiling dimensions used for MLP GMM
+# Megablox/JAX Ragged Dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
+# Tokamax Ragged Dot - supports all 18 configs
 wi_tile_fwd_batch_seq: 512
 wi_tile_fwd_embed_dim: 1024
 wi_tile_fwd_mlp_dim: 1024
@@ -201,17 +202,19 @@ wo_tile_dlhs_mlp_dim: 1024
 wo_tile_drhs_batch_seq: 512
 wo_tile_drhs_embed_dim: 1024
 wo_tile_drhs_mlp_dim: 1024
-norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
+norm_topk_prob: False # Boolean to enable the top-k probability normalization. Qwen3-specific normalization of router weights.
 
-# how the expert axis is used to shard attention weights and activations
+# How the expert axis is used to shard attention weights and activations
 # "fsdp" (ep acts as fsdp parallelism)
 # "context" (ep acts as context parallelism, training only)
 expert_shard_attention_option: "fsdp"
 
-# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
-moe_fsdp_use_two_stage_all_gather: false
+# When MoE weight matrices are sharded on both FSDP and FSDP-transpose axes, use two separate All-Gather calls
+moe_fsdp_use_two_stage_all_gather: False
 # Shard the MoE weights on the num_expert dimension. This can be performant when num_experts % fsdp_parallelism != 0.
 fsdp_shard_on_exp: False
+# Use the fsdp and fsdp_transpose axes for sharding the MoE weights
+use_2d_fsdp_sharding: False
 
 # deepseek moe
 base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
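To make the new flags concrete, here is a minimal JAX sketch of the layout `use_2d_fsdp_sharding` asks for: one MoE-style weight tiled along both the `fsdp` and `fsdp_transpose` mesh axes. The mesh shape is an assumption for illustration (8 devices as 4 x 2); nothing below is MaxText internals:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices arranged 4 x 2, i.e. ici_fsdp_parallelism=4 and
# ici_fsdp_transpose_parallelism=2.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=("fsdp", "fsdp_transpose"))

# An [embed, mlp] weight tiled along BOTH axes ("2D FSDP"): each device holds
# a 1/8 tile instead of the 1/4 slab that 1D FSDP would give it.
w = jnp.zeros((7168, 18432), dtype=jnp.bfloat16)
w = jax.device_put(w, NamedSharding(mesh, P("fsdp", "fsdp_transpose")))
print(w.sharding.spec)  # PartitionSpec('fsdp', 'fsdp_transpose')
```

With `moe_fsdp_use_two_stage_all_gather`, re-materializing such a weight before the matmul would then take two All-Gathers, one per axis, as the comment above describes.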
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model config for DeepSeek V3 - 671B that uses FSDP on two logical axes.
+
+# For DeepSeek default device-limited routing,
+# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
+
+base_emb_dim: 7168
+base_num_query_heads: 128
+base_num_kv_heads: 128
+base_mlp_dim: 18432
+base_moe_mlp_dim: 2048
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+num_experts: 256
+num_experts_per_tok: 8
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek"
+# MLA
+attention_type: "mla"
+q_lora_rank: 1536
+kv_lora_rank: 512
+qk_nope_head_dim: 128
+qk_rope_head_dim: 64
+v_head_dim: 128
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
+
+override_logical_axis_rules: True
+mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
+logical_axis_rules: [
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_heads', []],
+  ['embed', ['fsdp']],
+  ['embed_no_exp', ['fsdp']],
+  ['q_lora', ['fsdp']],
+  ['kv_lora', ['fsdp']],
+  ['q_lora_up_proj', ['fsdp_transpose']],
+  ['kv_lora_up_proj', ['fsdp_transpose']],
+  ['q_heads', ['fsdp_transpose']],
+  ['kv_heads', ['fsdp_transpose']],
+  ['heads', ['fsdp_transpose']],
+  ['mlp', ['fsdp_transpose']],
+]
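Reading the rule list: each entry maps a logical tensor axis (e.g. `embed`, `mlp`) to the mesh axes it is sharded over, so the MLA projections and MLP weights split across `fsdp` and `fsdp_transpose` while activation heads stay unsharded. A minimal sketch of how such rules resolve to a `PartitionSpec` (the resolver and the rule subset are illustrative; in MaxText/Flax this job is done by the logical-axis-rules machinery):

```python
from jax.sharding import PartitionSpec as P

# Rule subset copied from the config above; an empty list there means the
# logical axis is replicated.
LOGICAL_AXIS_RULES = {
    "activation_heads": (),
    "embed": ("fsdp",),
    "q_lora": ("fsdp",),
    "heads": ("fsdp_transpose",),
    "mlp": ("fsdp_transpose",),
}

def logical_to_spec(logical_axes):
  """Resolve logical axis names to mesh axes via the rule table."""
  return P(*(LOGICAL_AXIS_RULES.get(name) or None for name in logical_axes))

# An MLP kernel is logically [embed, mlp], so it lands on both FSDP axes:
print(logical_to_spec(("embed", "mlp")))  # PartitionSpec('fsdp', 'fsdp_transpose')
```

Because the config sets `override_logical_axis_rules: True`, these rules replace the base.yml list instead of merging into it, so the table above fully describes the model's sharding once it is selected via `model_name=deepseek3-671b-2dfsdp`.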

src/MaxText/configs/types.py

Lines changed: 40 additions & 28 deletions
@@ -187,6 +187,7 @@ class ProfilerType(str, Enum):
     "deepseek2-16b",
     "deepseek2-236b",
     "deepseek3-671b",
+    "deepseek3-671b-2dfsdp",
     "deepseek3-test",
     "deepseek3-tiny",
     "kimi-k2-1t",
@@ -233,6 +234,10 @@ class RunInfo(BaseModel):
   )
   model_name: ModelName = Field("default", description="The name of the model configuration to use.")
   override_model_config: bool = Field(False, description="If True, allows overriding model parameters via CLI.")
+  override_logical_axis_rules: bool = Field(
+      False,
+      description="If True, logical_axis_rules will be overridden instead of merged.",
+  )
   log_config: bool = Field(
       True,
       description="If True, prints the final configuration after initialization.",
@@ -563,6 +568,10 @@ class MoEGeneral(BaseModel):
       description="Shard the MoE weights on the num_expert dimension. Can be performant when "
       "num_experts % fsdp_parallelism != 0.",
   )
+  use_2d_fsdp_sharding: bool = Field(
+      False,
+      description="Use `fsdp` and `fsdp_transpose` axes for 2D FSDP sharding.",
+  )
   norm_topk_prob: bool = Field(
       False,
       description="Enable top-k probability normalization for router weights (Qwen3-specific).",
@@ -2137,34 +2146,37 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           self.dcn_autoregressive_parallelism,
       ]
     else:
-      self.ici_parallelism = [
-          self.ici_data_parallelism,
-          self.ici_pipeline_parallelism,
-          self.ici_fsdp_parallelism,
-          self.ici_fsdp_transpose_parallelism,
-          self.ici_sequence_parallelism,
-          self.ici_context_parallelism,
-          self.ici_context_autoregressive_parallelism,
-          self.ici_tensor_parallelism,
-          self.ici_tensor_transpose_parallelism,
-          self.ici_tensor_sequence_parallelism,
-          self.ici_expert_parallelism,
-          self.ici_autoregressive_parallelism,
-      ]
-      self.dcn_parallelism = [
-          self.dcn_data_parallelism,
-          self.dcn_pipeline_parallelism,
-          self.dcn_fsdp_parallelism,
-          self.dcn_fsdp_transpose_parallelism,
-          self.dcn_sequence_parallelism,
-          self.dcn_context_parallelism,
-          self.dcn_context_autoregressive_parallelism,
-          self.dcn_tensor_parallelism,
-          self.dcn_tensor_transpose_parallelism,
-          self.dcn_tensor_sequence_parallelism,
-          self.dcn_expert_parallelism,
-          self.dcn_autoregressive_parallelism,
-      ]
+      ici_map = {
+          "data": self.ici_data_parallelism,
+          "stage": self.ici_pipeline_parallelism,
+          "fsdp": self.ici_fsdp_parallelism,
+          "fsdp_transpose": self.ici_fsdp_transpose_parallelism,
+          "sequence": self.ici_sequence_parallelism,
+          "context": self.ici_context_parallelism,
+          "context_autoregressive": self.ici_context_autoregressive_parallelism,
+          "tensor": self.ici_tensor_parallelism,
+          "tensor_transpose": self.ici_tensor_transpose_parallelism,
+          "tensor_sequence": self.ici_tensor_sequence_parallelism,
+          "expert": self.ici_expert_parallelism,
+          "autoregressive": self.ici_autoregressive_parallelism,
+      }
+      self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
+
+      dcn_map = {
+          "data": self.dcn_data_parallelism,
+          "stage": self.dcn_pipeline_parallelism,
+          "fsdp": self.dcn_fsdp_parallelism,
+          "fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
+          "sequence": self.dcn_sequence_parallelism,
+          "context": self.dcn_context_parallelism,
+          "context_autoregressive": self.dcn_context_autoregressive_parallelism,
+          "tensor": self.dcn_tensor_parallelism,
+          "tensor_transpose": self.dcn_tensor_transpose_parallelism,
+          "tensor_sequence": self.dcn_tensor_sequence_parallelism,
+          "expert": self.dcn_expert_parallelism,
+          "autoregressive": self.dcn_autoregressive_parallelism,
+      }
+      self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
 
     # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
     if isinstance(self.decoder_block, str):
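The net effect of this refactor: the two fixed-order lists become name-keyed maps indexed by `mesh_axes`, so a config that redefines the mesh (like the new DSv3 config's five-axis `mesh_axes`) still gets its parallelism degrees in matching order. A standalone sketch of the pattern, with made-up degree values:

```python
# Name-keyed lookup as introduced above; the degrees are made-up examples.
ici_map = {
    "data": 1,
    "stage": 1,
    "fsdp": 64,
    "fsdp_transpose": 4,
    "sequence": 1,
    "context": 1,
    "expert": 1,
    "autoregressive": 1,
}

# The five-axis mesh from the new DSv3 config: axes the config omits are
# simply never looked up, and the order follows mesh_axes, not a fixed list.
mesh_axes = ["data", "fsdp", "fsdp_transpose", "expert", "context"]
ici_parallelism = [ici_map[axis] for axis in mesh_axes]
print(ici_parallelism)  # [1, 64, 4, 1, 1]
```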
