@@ -98,6 +98,150 @@ def compute_gpt_attention_flops_per_device(self, kwargs: dict) -> float:
 
     return attention_flops / 1e12  # return tflops
 
+  def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float:
+    """
+    Computes the attention-mechanism training TFLOPs per device for a Qwen3-Next model.
+    Only counts the attention operations themselves (non-weight FLOPs).
+    """
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    N = kwargs["base_num_decoder_layers"]
+    cycle_interval = kwargs["inhomogeneous_layer_cycle_interval"]
+
+    # Layer counts: every cycle_interval-th layer uses full attention, the rest use linear attention.
+    num_full_layers = N // cycle_interval
+    num_linear_layers = N - num_full_layers
+
+    # 1. Full Attention FLOPs (Causal)
+    D_head = kwargs["head_dim"]
+    H_q = kwargs["base_num_query_heads"]
+    # QK^T and SV are two matmuls at 2 FLOPs per MAC; maxtext_utils halves this for
+    # causal masking, leaving a net factor of 2. The factor of 3 covers fwd + bwd.
+    # Formula: 2 * 3 * B * S^2 * H * D
+    full_attn_flops = 2 * 3 * num_full_layers * B * (S**2) * H_q * D_head
+
+    # 2. Linear Attention (Gated Delta Net) FLOPs
+    H_v = kwargs["gdn_num_value_heads"]
+    D_k = kwargs["gdn_key_head_dim"]
+    D_v = kwargs["gdn_value_head_dim"]
+    C = kwargs["gdn_chunk_size"]
+
+    # Formulas from maxtext_utils.calculate_gated_delta_net_flops_per_device
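+    # flops_intra is the within-chunk attention over chunks of size C; flops_inter is the
+    # cross-chunk recurrent state update (the state is a D_k x D_v matrix per value head).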
+    flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2)
+    flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v)
+
+    # 3 for fwd+bwd
+    linear_attn_flops = 3 * num_linear_layers * (flops_intra + flops_inter)
+
+    return (full_attn_flops + linear_attn_flops) / 1e12
+
+  @pytest.mark.cpu_only
+  def test_qwen3_next_flops(self):
+    """Test Qwen3-Next FLOPs calculation"""
+    kwargs = {
+        "model_name": "qwen3-next-80b-a3b",
+        "override_model_config": True,
+        "per_device_batch_size": 1,
+        "max_target_length": 4096,
+        "decoder_block": "qwen3_next",
+        "gradient_accumulation_steps": 1,
+        "skip_jax_distributed_system": True,
+        # Core Architectural Parameters
+        "base_emb_dim": 2048,
+        "base_num_decoder_layers": 48,
+        "base_num_query_heads": 16,
+        "base_num_kv_heads": 2,
+        "head_dim": 256,
+        "vocab_size": 151936,
+        # MoE Parameters
+        "base_mlp_dim": 512,  # Note: maxtext_utils uses moe_mlp_dim for calculations
+        "base_moe_mlp_dim": 512,
+        "num_experts": 512,
+        "num_experts_per_tok": 10,
+        "mlp_activations": ["silu", "linear"],
+        # Qwen3-Next Specific Parameters
+        "inhomogeneous_layer_cycle_interval": 4,
+        "gdn_conv_kernel_dim": 4,
+        "gdn_key_head_dim": 128,
+        "gdn_value_head_dim": 128,
+        "gdn_num_key_heads": 16,
+        "gdn_num_value_heads": 32,
+        "gdn_chunk_size": 64,
+    }
+
+    # 1. Calculate Attention TFLOPs
+    attention_tflops = self.compute_qwen3_next_attention_flops_per_device(kwargs)
+
+    # 2. Calculate Learnable Weight Active Params
+    # Config Shortcuts
+    emb_dim = kwargs["base_emb_dim"]
+    vocab = kwargs["vocab_size"]
+    N = kwargs["base_num_decoder_layers"]
+
+    # MoE Active Params (per layer)
+    # The FFN uses SwiGLU (3 matrices); Qwen3-Next has 1 shared expert plus num_experts_per_tok routed experts.
+    # Params = Gate + Shared + Routed
+    # Gate: emb_dim * num_experts
+    # Expert: 3 * emb_dim * moe_mlp_dim
+    moe_mlp_dim = kwargs["base_moe_mlp_dim"]
+    num_experts = kwargs["num_experts"]
+    num_routed = kwargs["num_experts_per_tok"]
+
+    params_moe_layer = (
+        (emb_dim * num_experts) + (3 * emb_dim * moe_mlp_dim * 1) + (3 * emb_dim * moe_mlp_dim * num_routed)
+    )
+
+    # Full Attention Params (per full-attention layer)
+    Hq = kwargs["base_num_query_heads"]
+    Hkv = kwargs["base_num_kv_heads"]
+    Hd = kwargs["head_dim"]
+    # Q, K, V, Out projections
+    params_full_attn = (emb_dim * (Hq + 2 * Hkv) * Hd) + (Hq * Hd * emb_dim)
+
+    # GDN Linear Attention Params (per linear-attention layer)
+    Hk_g = kwargs["gdn_num_key_heads"]
+    Hv_g = kwargs["gdn_num_value_heads"]
+    Dk_g = kwargs["gdn_key_head_dim"]
+    Dv_g = kwargs["gdn_value_head_dim"]
+    K_conv = kwargs["gdn_conv_kernel_dim"]
+
+    K_dim = Hk_g * Dk_g
+    V_dim = Hv_g * Dv_g
+
+    # Projections: qkvz (in -> 2K + 2V), ba (in -> 2 * Hv), out (V -> in)
+    params_gdn_proj = (emb_dim * (2 * K_dim + 2 * V_dim)) + (emb_dim * 2 * Hv_g) + (V_dim * emb_dim)
+    # Conv: depthwise over the 2K + V channels
+    params_gdn_conv = (2 * K_dim + V_dim) * K_conv
+
+    params_gdn_layer = params_gdn_proj + params_gdn_conv
+
+    # Total Active Params
+    # 48 layers with a cycle interval of 4 -> 12 full-attention layers and 36 linear-attention layers
+    num_full = N // kwargs["inhomogeneous_layer_cycle_interval"]
+    num_linear = N - num_full
+
+    total_active_params = (
+        (vocab * emb_dim)
+        + (num_full * (params_full_attn + params_moe_layer))
+        + (num_linear * (params_gdn_layer + params_moe_layer))
+    )
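+    # With the settings above this works out to roughly 3.46e9 active parameters,
+    # consistent with the "A3B" (~3B active) naming.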
+
+    # Weight TFLOPs = 6 * B * S * P
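+    # (2 FLOPs per parameter per token on the forward pass, 4 on the backward pass)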
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    weight_tflops = 6 * B * S * total_active_params / 1e12
+
+    golden_tflops = weight_tflops + attention_tflops
+
+    # Run Calculation
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        **kwargs,
+    )
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
   @pytest.mark.cpu_only
   def test_llama2_7b_flops(self):
     """Test Llama2 7b Flops calculation with default parameters"""