Skip to content

Commit e009264

Browse files
committed
Merge pull request #2999 from AI-Hypercomputer:rbierneni-qwen3-next-tflops
PiperOrigin-RevId: 863406702
2 parents 41be8da + 3fcbcc7 commit e009264

3 files changed

Lines changed: 235 additions & 2 deletions

File tree

src/MaxText/configs/models/qwen3-next-80b-a3b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ normalization_layer_epsilon: 1.0e-6
3131
base_mlp_dim: 512
3232
base_moe_mlp_dim: 512
3333
num_experts: 512
34+
shared_experts: 1
3435
num_experts_per_tok: 10
3536
norm_topk_prob: True
3637

src/maxtext/utils/maxtext_utils.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,82 @@ def get_dense_moe_layers(config):
488488
elif config.decoder_block == DecoderBlockType.LLAMA4:
489489
num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
490490
num_dense_layers = config.num_decoder_layers - num_moe_layers
491+
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
492+
num_moe_layers = config.num_decoder_layers
493+
num_dense_layers = 0
491494
else:
492-
raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")
495+
raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
493496

494497
return num_dense_layers, num_moe_layers
495498

496499

500+
def calculate_gated_delta_net_flops_per_device(config):
  """Estimate forward-pass FLOPs of one Gated Delta Net (linear attention) layer.

  The count is split into two buckets so the caller can report them separately:
  learnable-weight matmuls (input/output projections plus the depthwise causal
  convolution) and the attention-like chunked delta-rule core, which has no
  learnable weights.

  Args:
    config: model config exposing per_device_batch_size, max_target_length,
      emb_dim and the gdn_* hyperparameters.

  Returns:
    Tuple (gdn_weight_flops, gdn_attn_flops) of raw FLOPs (not TFLOPs, no
    fwd/bwd multiplier applied).
  """
  batch = config.per_device_batch_size
  seq = config.max_target_length
  emb = config.emb_dim

  key_heads = config.gdn_num_key_heads
  value_heads = config.gdn_num_value_heads
  key_dim = config.gdn_key_head_dim
  value_dim = config.gdn_value_head_dim
  chunk = config.gdn_chunk_size
  conv_kernel = config.gdn_conv_kernel_dim

  total_key_dim = key_heads * key_dim  # flattened K width (H_k * D_k)
  total_value_dim = value_heads * value_dim  # flattened V width (H_v * D_v)

  # Every term below is linear in the number of processed tokens.
  tokens = batch * seq

  # 1. Projections (learnable weights); each matmul costs 2 * tokens * in * out:
  #    in_proj_qkvz: emb -> 2*K + 2*V
  #    in_proj_ba:   emb -> 2*H_v
  #    out_proj:     V   -> emb
  projection_flops = (
      2 * tokens * emb * (2 * total_key_dim + 2 * total_value_dim)
      + 2 * tokens * emb * (2 * value_heads)
      + 2 * tokens * total_value_dim * emb
  )

  # 2. Depthwise convolution (learnable weights) over the (2*K + V) channels:
  #    2 * tokens * channels * kernel.
  conv_flops = 2 * tokens * (2 * total_key_dim + total_value_dim) * conv_kernel

  # 3. Chunked delta-rule core (attention-like, no learnable weights).
  # Assumptions: K is broadcast to H_v heads when H_v > H_k, and the number
  # of chunks N satisfies N * C ~ S. Operand shapes:
  #   Q, K: [B, S, H_v, D_k]   V: [B, S, H_v, D_v]
  #   intra-chunk attention A: [B, N, H_v, C, C]
  #   recurrent state S: [B, N, H_v, D_k, D_v]
  #
  # Intra-chunk: K*K attention and A*K cumulative keys (D_k terms), A*V values
  # (D_v term), plus the iterative inner_attn_body refinement, which costs
  # roughly (C - 1) * B * H_v * N * C^2 ~ B * H_v * S * C^2.
  intra_flops = 2 * tokens * value_heads * chunk * (2 * key_dim + value_dim) + (
      batch * value_heads * seq * chunk**2
  )
  # Inter-chunk: Q*K and core A*V over the chunk dimension, plus three
  # D_k*D_v state contractions (v_prime = K*S, attn_inter = Q*S, update = K*V).
  inter_flops = 2 * tokens * value_heads * chunk * (key_dim + value_dim) + (
      6 * tokens * value_heads * key_dim * value_dim
  )

  gdn_weight_flops = projection_flops + conv_flops
  gdn_attn_flops = intra_flops + inter_flops
  return gdn_weight_flops, gdn_attn_flops
565+
566+
497567
def calculate_gemma3_vision_layers_tflops_per_device(config):
498568
"""
499569
Estimate TFLOPs for Gemma3 vision encoder (ViT-style).
@@ -634,7 +704,7 @@ def calculate_tflops_training_per_device(config, log=True):
634704
# MLP flops
635705
if config.num_experts > 1:
636706
# calculation based on dropless implementation
637-
if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
707+
if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT):
638708
total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
639709
else:
640710
gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
@@ -702,6 +772,24 @@ def calculate_tflops_training_per_device(config, log=True):
702772
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
703773
)
704774
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
775+
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
776+
gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config)
777+
cycle_interval = config.inhomogeneous_layer_cycle_interval
778+
num_full_attn_layers = config.num_decoder_layers // cycle_interval
779+
num_linear_attn_layers = config.num_decoder_layers - num_full_attn_layers
780+
781+
# Weights TFLOPs:
782+
total_weights = (
783+
total_ffn_flops
784+
+ embedding_flops
785+
+ (qkv_flops + projection_flops) * num_full_attn_layers
786+
+ gdn_weight_flops_per_layer * num_linear_attn_layers
787+
)
788+
learnable_weight_tflops = total_weights * 3 / 10**12
789+
790+
# Attention TFLOPs:
791+
total_attn = (causal_attention_flops * num_full_attn_layers) + (gdn_attn_flops_per_layer * num_linear_attn_layers)
792+
attention_tflops = total_attn * 3 / 10**12
705793
else:
706794
# multiply by 3 for both feed forward and back propagation flops
707795
learnable_weight_tflops = (

tests/unit/flop_calculation_test.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,150 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:
9898

9999
return attention_flops / 1e12 # return tflops
100100

101+
def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
  """Golden attention-only TFLOPs per device for a Qwen3-Next model.

  Counts only the attention-mechanism (non-weight) operations: causal full
  attention on every `inhomogeneous_layer_cycle_interval`-th layer, and the
  Gated Delta Net core on all remaining layers. The formulas deliberately
  mirror maxtext_utils so the test pins the production calculation.
  """
  batch = kwargs["per_device_batch_size"]
  seq = kwargs["max_target_length"]
  layers = kwargs["base_num_decoder_layers"]
  interval = kwargs["inhomogeneous_layer_cycle_interval"]

  # One full-attention layer per cycle; every other layer is linear (GDN).
  full_layers = layers // interval
  linear_layers = layers - full_layers

  # 1. Causal full attention.
  # Two matmuls (QK^T and SV) already halved for causal masking — matching
  # maxtext_utils — then times 3 for fwd + bwd: 2 * 3 * B * S^2 * H * D.
  head_dim = kwargs["head_dim"]
  query_heads = kwargs["base_num_query_heads"]
  full_attn = 2 * 3 * full_layers * batch * (seq**2) * query_heads * head_dim

  # 2. Linear attention (Gated Delta Net) core, same terms as
  # maxtext_utils.calculate_gated_delta_net_flops_per_device.
  gdn_heads = kwargs["gdn_num_value_heads"]
  k_dim = kwargs["gdn_key_head_dim"]
  v_dim = kwargs["gdn_value_head_dim"]
  chunk = kwargs["gdn_chunk_size"]

  intra = 2 * batch * seq * gdn_heads * chunk * (2 * k_dim + v_dim) + (
      batch * gdn_heads * seq * chunk**2
  )
  inter = (2 * batch * seq * gdn_heads * chunk * (k_dim + v_dim)) + (
      6 * batch * seq * gdn_heads * k_dim * v_dim
  )
  # Times 3 for fwd + bwd.
  linear_attn = 3 * linear_layers * (intra + inter)

  return (full_attn + linear_attn) / 1e12
137+
138+
@pytest.mark.cpu_only
def test_qwen3_next_flops(self):
  """Cross-check Qwen3-Next training TFLOPs against a hand-built golden value."""
  kwargs = {
      "model_name": "qwen3-next-80b-a3b",
      "override_model_config": True,
      "per_device_batch_size": 1,
      "max_target_length": 4096,
      "decoder_block": "qwen3_next",
      "gradient_accumulation_steps": 1,
      "skip_jax_distributed_system": True,
      # Core architecture.
      "base_emb_dim": 2048,
      "base_num_decoder_layers": 48,
      "base_num_query_heads": 16,
      "base_num_kv_heads": 2,
      "head_dim": 256,
      "vocab_size": 151936,
      # MoE — maxtext_utils bases its math on moe_mlp_dim, not mlp_dim.
      "base_mlp_dim": 512,
      "base_moe_mlp_dim": 512,
      "num_experts": 512,
      "num_experts_per_tok": 10,
      "mlp_activations": ["silu", "linear"],
      # Qwen3-Next specific (hybrid attention + Gated Delta Net).
      "inhomogeneous_layer_cycle_interval": 4,
      "gdn_conv_kernel_dim": 4,
      "gdn_key_head_dim": 128,
      "gdn_value_head_dim": 128,
      "gdn_num_key_heads": 16,
      "gdn_num_value_heads": 32,
      "gdn_chunk_size": 64,
  }

  # Attention (non-weight) TFLOPs from the shared golden helper.
  attention_tflops = self.compute_qwen3_next_attention_flops_per_device(kwargs)

  # --- Active-parameter count feeding the weight TFLOPs ---
  emb = kwargs["base_emb_dim"]
  depth = kwargs["base_num_decoder_layers"]

  # MoE block (identical on every layer): router gate (emb * num_experts)
  # plus 1 shared and num_experts_per_tok routed experts, each a 3-matrix
  # SwiGLU FFN of size emb * moe_mlp_dim.
  ffn = kwargs["base_moe_mlp_dim"]
  moe_params = (
      emb * kwargs["num_experts"]
      + 3 * emb * ffn * 1
      + 3 * emb * ffn * kwargs["num_experts_per_tok"]
  )

  # Full-attention layers: Q/K/V projections plus the output projection.
  q_heads = kwargs["base_num_query_heads"]
  kv_heads = kwargs["base_num_kv_heads"]
  hdim = kwargs["head_dim"]
  full_attn_params = emb * (q_heads + 2 * kv_heads) * hdim + q_heads * hdim * emb

  # GDN linear-attention layers: qkvz (emb -> 2K+2V), ba (emb -> 2*H_v) and
  # out (V -> emb) projections, plus the depthwise conv over 2K+V channels.
  k_total = kwargs["gdn_num_key_heads"] * kwargs["gdn_key_head_dim"]
  v_total = kwargs["gdn_num_value_heads"] * kwargs["gdn_value_head_dim"]
  gdn_params = (
      emb * (2 * k_total + 2 * v_total)
      + emb * 2 * kwargs["gdn_num_value_heads"]
      + v_total * emb
      + (2 * k_total + v_total) * kwargs["gdn_conv_kernel_dim"]
  )

  # Layer split: 12 full-attention vs 36 linear layers for depth 48, cycle 4.
  full_layers = depth // kwargs["inhomogeneous_layer_cycle_interval"]
  linear_layers = depth - full_layers

  active_params = (
      kwargs["vocab_size"] * emb
      + full_layers * (full_attn_params + moe_params)
      + linear_layers * (gdn_params + moe_params)
  )

  # Weight TFLOPs = 6 * tokens * params (2 FLOPs per MAC, x3 for fwd + bwd).
  tokens = kwargs["per_device_batch_size"] * kwargs["max_target_length"]
  golden_tflops = 6 * tokens * active_params / 1e12 + attention_tflops

  cfg = pyconfig.initialize(
      [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
      **kwargs,
  )
  calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)

  self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
244+
101245
@pytest.mark.cpu_only
102246
def test_llama2_7b_flops(self):
103247
"""Test Llama2 7b Flops calculation with default parameters"""

0 commit comments

Comments
 (0)