Skip to content

Commit 6b891da

Browse files
authored
[Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660)
* enable trtllm_all_reduce fusion kernel in glm model * fix conflict * format update * fix a bug * modify test * modify test * support empty tensor and modify test * fix test_linear config issues * modify test name * add edge test case * modify format * fix conflict * modify default max token num in trtllm_allreduce_fusion * add max token num branch for trtllm_allreduce_fusion * fix format * fix rmsnorm config issue * modify 2025 to 2026 * using compat guard * Lazily import flashinfer.comm and fix test config issue * fix test issues * add flashinfer cache dir clean machine * fix some issues
1 parent e53f518 commit 6b891da

17 files changed

Lines changed: 871 additions & 11 deletions

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ def __init__(
671671
self.pod_ip: str = None
672672
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
673673
self.disable_custom_all_reduce: bool = False
674+
self.enable_flashinfer_allreduce_fusion: bool = False
674675
for key, value in args.items():
675676
if hasattr(self, key):
676677
setattr(self, key, value)

fastdeploy/engine/args_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ class EngineArgs:
274274
Flag to disable the custom all-reduce kernel.
275275
"""
276276

277+
enable_flashinfer_allreduce_fusion: bool = False
278+
"""
279+
Flag to enable all reduce fusion kernel in flashinfer.
280+
"""
281+
277282
use_internode_ll_two_stage: bool = False
278283
"""
279284
Flag to use the internode_ll_two_stage kernel.
@@ -995,6 +1000,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
9951000
default=EngineArgs.disable_custom_all_reduce,
9961001
help="Flag to disable custom all-reduce.",
9971002
)
1003+
parallel_group.add_argument(
1004+
"--enable-flashinfer-allreduce-fusion",
1005+
action="store_true",
1006+
default=EngineArgs.enable_flashinfer_allreduce_fusion,
1007+
help="Flag to enable all reduce fusion kernel in flashinfer.",
1008+
)
9981009
parallel_group.add_argument(
9991010
"--use-internode-ll-two-stage",
10001011
action="store_true",

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,7 @@ def _start_worker_service(self):
25032503
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
25042504
"enable_entropy": self.cfg.model_config.enable_entropy,
25052505
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
2506+
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
25062507
}
25072508
for worker_flag, value in worker_store_true_flag.items():
25082509
if value:

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ def _start_worker_service(self):
656656
"enable_entropy": self.cfg.model_config.enable_entropy,
657657
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
658658
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
659+
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
659660
}
660661
for worker_flag, value in worker_store_true_flag.items():
661662
if value:
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
from typing import Optional, Tuple
18+
19+
import paddle
20+
import paddle.distributed as dist
21+
22+
from fastdeploy.config import FDConfig
23+
from fastdeploy.model_executor.utils import has_flashinfer
24+
from fastdeploy.utils import get_logger
25+
26+
logger = get_logger("flashinfer", "flashinfer.log")
27+
28+
_flashinfer_comm = None
29+
_workspace_manager = None
30+
31+
32+
def _get_flashinfer_comm():
    """Lazily import flashinfer.comm to avoid side effects at module load time."""
    global _flashinfer_comm
    # Import at most once per process; only attempt it when the package exists.
    if _flashinfer_comm is None and has_flashinfer():
        try:
            with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
                import flashinfer.comm as comm
            _flashinfer_comm = comm
        except ImportError:
            logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")
    return _flashinfer_comm
46+
47+
48+
class FlashInferWorkspaceManager:
    """Owns the IPC workspace used by FlashInfer's trtllm all-reduce fusion.

    One instance is shared per process (see the module-level singleton). The
    workspace is created lazily and re-created whenever the communication
    geometry changes, because a workspace allocated for a smaller
    ``max_token_num``/``hidden_dim`` must not be reused for larger inputs.
    """

    def __init__(self):
        # IPC workspace state; populated by initialize(), reset by cleanup().
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        # Geometry the current workspace was sized for.
        self.max_token_num = None
        self.hidden_dim = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create (or re-create) the IPC workspace for all-reduce fusion.

        No-op when an existing workspace already matches the requested
        geometry. Logs a warning and leaves ``self.initialized`` False when
        FlashInfer's comm module is unavailable.
        """
        # Reuse only when the WHOLE geometry matches, not just world_size:
        # a workspace sized for a smaller hidden_dim/max_token_num would be
        # undersized for larger inputs.
        if (
            self.initialized
            and self.world_size == world_size
            and self.max_token_num == max_token_num
            and self.hidden_dim == hidden_dim
        ):
            return

        comm = _get_flashinfer_comm()
        if comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return

        # Release any previously allocated workspace before re-allocating.
        self.cleanup()

        self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        self.max_token_num = max_token_num
        self.hidden_dim = hidden_dim
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace; errors are logged, never raised."""
        if self.initialized and self.ipc_handles is not None:
            try:
                comm = _get_flashinfer_comm()
                if comm is not None:
                    # NOTE(review): dist.get_group() with no argument is assumed
                    # to return the default (global) process group — confirm it
                    # matches the group the workspace was created with.
                    comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.max_token_num = None
                self.hidden_dim = None
                self.initialized = False
104+
105+
106+
# Process-wide singleton; shared by ensure_workspace_initialized() and
# cleanup_flashinfer_workspace() so all layers reuse one IPC workspace.
_workspace_manager = FlashInferWorkspaceManager()
107+
108+
109+
def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Make sure the shared FlashInfer workspace exists; return True on success."""
    fi_comm = _get_flashinfer_comm()
    if fi_comm is None or not has_flashinfer():
        return False

    assert fd_config is not None
    tp_size = fd_config.parallel_config.tensor_parallel_size
    if tp_size <= 1:
        # Nothing to all-reduce on a single rank.
        return False

    rank = dist.get_rank()

    needs_init = (not _workspace_manager.initialized) or _workspace_manager.world_size != tp_size
    if needs_init:
        _workspace_manager.initialize(
            world_size=tp_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized
134+
135+
136+
def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Returns ``(norm_out, residual_out)`` on success, or ``(None, None)`` when
    the fused path cannot be used (FlashInfer missing, single GPU, workspace
    unavailable, or too many tokens) so the caller can fall back to the
    unfused all-reduce + rmsnorm path.

    Args:
        fd_config: FastDeploy config; supplies tensor_parallel_size.
        input_tensor: Activation of shape [token_num, hidden_dim] to be
            all-reduced across TP ranks.
        residual: Residual input added before the norm; same shape as
            ``input_tensor`` — TODO confirm against kernel contract.
        weight: RMSNorm gamma.
        eps: RMSNorm epsilon.
        max_token_num: Workspace capacity in tokens.
        use_oneshot: Forwarded to the kernel; None lets FlashInfer decide.
        trigger_completion_at_end: Forwarded kernel tuning flag.
        fp32_acc: Forwarded kernel accumulation flag.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        logger.debug("FlashInfer not available, falling back to standard " "implementation")
        return None, None

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    # Explicit fallback instead of a bare assert (asserts are stripped under
    # `python -O`, which would let an oversized batch overrun the workspace).
    if token_num > max_token_num:
        logger.debug(f"token_num {token_num} exceeds max_token_num {max_token_num}, falling back")
        return None, None

    # Support empty tensors BEFORE touching the workspace: a zero-token batch
    # must not trigger IPC workspace allocation or the kernel launch.
    if token_num == 0:
        return paddle.empty_like(input_tensor), paddle.empty_like(residual)

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    # Outputs are written in place by the fused kernel.
    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)

    comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out
204+
205+
206+
def cleanup_flashinfer_workspace():
    """Release the process-wide FlashInfer IPC workspace, if one was created."""
    manager = _workspace_manager
    if manager is not None:
        manager.cleanup()

fastdeploy/model_executor/layers/linear.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ def __init__(
854854
skip_quant: bool = False,
855855
weight_dtype: str = "",
856856
layer_id: int = -1,
857+
enable_all_reduce_fusion: bool = None,
857858
):
858859
"""
859860
Initialize a linear layer with additional parameters for inference and quantization.
@@ -865,9 +866,17 @@ def __init__(
865866
input_size (int): Number of input features. Defaults to None.
866867
output_size (int): Number of output features. Defaults to None.
867868
with_bias (bool): Whether to include bias or not. Defaults to False.
868-
skip_quant (bool): Whether to skip quantization. Defaults to False.
869+
skip_quant (bool): Whether to skip quantization or not. Defaults to False.
870+
enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion.
871+
If None, it is determined by the config flag and prefix. Defaults to None.
869872
"""
870873
self.fd_config = fd_config
874+
if enable_all_reduce_fusion is None:
875+
self.enable_all_reduce_fusion = False
876+
else:
877+
self.enable_all_reduce_fusion = (
878+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion
879+
)
871880
self.ep_size = fd_config.parallel_config.expert_parallel_size
872881
self.tp_size = fd_config.parallel_config.tensor_parallel_size
873882
self.tp_group = fd_config.parallel_config.tp_group
@@ -945,7 +954,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
945954

946955
out = self.quant_method.apply(self, x)
947956

948-
if self.reduce_results and self.tp_size > 1:
957+
need_tp_all_reduce = (
958+
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
959+
)
960+
if need_tp_all_reduce:
949961
out = tensor_model_parallel_all_reduce(out, self.tp_group)
950962

951963
return out

fastdeploy/model_executor/layers/normalization.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_batch_invariant_mode_enabled,
3636
rms_norm_batch_invariant,
3737
)
38+
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
3839
from .utils import get_tensor, modules_to_convert
3940

4041

@@ -122,6 +123,10 @@ def __init__(
122123
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
123124
self.tp_group = self.fd_config.parallel_config.tp_group
124125
is_input_norm = prefix.endswith(".input_layernorm")
126+
self.enable_all_reduce_fusion = (
127+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
128+
)
129+
125130
self.is_last_norm = prefix.endswith(".norm")
126131
self.split_x = (
127132
self.fd_config.parallel_config.use_sequence_parallel_moe
@@ -240,6 +245,12 @@ def forward(
240245
norm_out = rms_norm(x, self.weight, self.eps)
241246
return norm_out.astype(x_dtype), residual_out
242247
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
248+
# enable trtllm all reduce fusion
249+
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
250+
norm_out = flashinfer_allreduce_residual_rmsnorm(
251+
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
252+
)
253+
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
243254
else:
244255
if is_batch_invariant_mode_enabled():
245256
# M-invariant path: per-row Triton kernel, no cross-row reduction

fastdeploy/model_executor/layers/quantization/mxfp4.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
"""
1616

17-
import importlib
18-
import importlib.util
1917
import math
2018
from enum import Enum
2119
from typing import Callable, Optional
@@ -25,11 +23,12 @@
2523

2624
from fastdeploy import envs
2725
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
28-
from fastdeploy.model_executor.utils import set_weight_attrs
26+
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
2927
from fastdeploy.platforms import current_platform
3028

3129
if current_platform.is_cuda():
3230
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
31+
3332
from fastdeploy.utils import get_logger
3433

3534
from ..moe import FusedMoE
@@ -59,10 +58,6 @@ def check_device_capability(num):
5958
return False
6059

6160

62-
def has_flashinfer():
63-
return importlib.util.find_spec("flashinfer") is not None
64-
65-
6661
def round_up(a, b):
6762
return ((a + b - 1) // b) * b
6863

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def __init__(
130130
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
131131
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
132132
self.tp_group = fd_config.parallel_config.tp_group
133-
134133
self.use_ep = self.expert_parallel_size > 1
135134
self.use_tp = self.tensor_parallel_size > 1
136135

@@ -229,6 +228,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
229228
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
230229
output_size=fd_config.model_config.hidden_size,
231230
layer_id=layer_id,
231+
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
232232
)
233233

234234
self.attn = Attention(

fastdeploy/model_executor/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
"""
1616

17+
import importlib
18+
import importlib.util
1719
import os
1820
import re
1921
from collections.abc import Mapping
@@ -553,6 +555,10 @@ def fn(loaded_weight_name, is_moe):
553555
return fn
554556

555557

558+
@cache
def has_flashinfer() -> bool:
    """Return True if the optional ``flashinfer`` package is importable.

    Cached because ``importlib.util.find_spec`` scans the filesystem and this
    check sits on per-forward hot paths, while package availability cannot
    change within the lifetime of the process.
    """
    return importlib.util.find_spec("flashinfer") is not None
560+
561+
556562
@cache
557563
def get_sm_version():
558564
if paddle.cuda.is_available():

0 commit comments

Comments
 (0)