|
| 1 | +""" |
| 2 | +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +from typing import Optional, Tuple |
| 18 | + |
| 19 | +import paddle |
| 20 | +import paddle.distributed as dist |
| 21 | + |
| 22 | +from fastdeploy.config import FDConfig |
| 23 | +from fastdeploy.model_executor.utils import has_flashinfer |
| 24 | +from fastdeploy.utils import get_logger |
| 25 | + |
| 26 | +logger = get_logger("flashinfer", "flashinfer.log") |
| 27 | + |
| 28 | +_flashinfer_comm = None |
| 29 | +_workspace_manager = None |
| 30 | + |
| 31 | + |
| 32 | +def _get_flashinfer_comm(): |
| 33 | + """Lazily import flashinfer.comm to avoid side effects at module load time.""" |
| 34 | + global _flashinfer_comm |
| 35 | + if _flashinfer_comm is not None: |
| 36 | + return _flashinfer_comm |
| 37 | + if has_flashinfer(): |
| 38 | + try: |
| 39 | + with paddle.use_compat_guard(enable=True, scope={"flashinfer"}): |
| 40 | + import flashinfer.comm as comm |
| 41 | + |
| 42 | + _flashinfer_comm = comm |
| 43 | + except ImportError: |
| 44 | + logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") |
| 45 | + return _flashinfer_comm |
| 46 | + |
| 47 | + |
| 48 | +class FlashInferWorkspaceManager: |
| 49 | + def __init__(self): |
| 50 | + self.workspace_tensor = None |
| 51 | + self.ipc_handles = None |
| 52 | + self.world_size = None |
| 53 | + self.rank = None |
| 54 | + self.initialized = False |
| 55 | + |
| 56 | + def initialize( |
| 57 | + self, |
| 58 | + world_size: int, |
| 59 | + rank: int, |
| 60 | + max_token_num: int, |
| 61 | + hidden_dim: int, |
| 62 | + group=None, |
| 63 | + use_fp32_lamport: bool = False, |
| 64 | + ): |
| 65 | + """Initialize workspace""" |
| 66 | + if self.initialized and self.world_size == world_size: |
| 67 | + return |
| 68 | + |
| 69 | + comm = _get_flashinfer_comm() |
| 70 | + if comm is None: |
| 71 | + logger.warning("FlashInfer comm not available, skipping workspace " "initialization") |
| 72 | + return |
| 73 | + |
| 74 | + self.cleanup() |
| 75 | + |
| 76 | + self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( |
| 77 | + rank, |
| 78 | + world_size, |
| 79 | + max_token_num, |
| 80 | + hidden_dim, |
| 81 | + group=group, |
| 82 | + use_fp32_lamport=use_fp32_lamport, |
| 83 | + ) |
| 84 | + |
| 85 | + self.world_size = world_size |
| 86 | + self.rank = rank |
| 87 | + self.initialized = True |
| 88 | + |
| 89 | + logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") |
| 90 | + |
| 91 | + def cleanup(self): |
| 92 | + """Clean up workspace""" |
| 93 | + if self.initialized and self.ipc_handles is not None: |
| 94 | + try: |
| 95 | + comm = _get_flashinfer_comm() |
| 96 | + if comm is not None: |
| 97 | + comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group()) |
| 98 | + except Exception as e: |
| 99 | + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") |
| 100 | + finally: |
| 101 | + self.workspace_tensor = None |
| 102 | + self.ipc_handles = None |
| 103 | + self.initialized = False |
| 104 | + |
| 105 | + |
| 106 | +_workspace_manager = FlashInferWorkspaceManager() |
| 107 | + |
| 108 | + |
| 109 | +def ensure_workspace_initialized( |
| 110 | + fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False |
| 111 | +): |
| 112 | + """Ensure workspace is initialized""" |
| 113 | + comm = _get_flashinfer_comm() |
| 114 | + if not has_flashinfer() or comm is None: |
| 115 | + return False |
| 116 | + |
| 117 | + assert fd_config is not None |
| 118 | + world_size = fd_config.parallel_config.tensor_parallel_size |
| 119 | + if world_size <= 1: |
| 120 | + return False |
| 121 | + |
| 122 | + rank = dist.get_rank() |
| 123 | + |
| 124 | + if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: |
| 125 | + _workspace_manager.initialize( |
| 126 | + world_size=world_size, |
| 127 | + rank=rank, |
| 128 | + max_token_num=max_token_num, |
| 129 | + hidden_dim=hidden_dim, |
| 130 | + use_fp32_lamport=use_fp32_lamport, |
| 131 | + ) |
| 132 | + |
| 133 | + return _workspace_manager.initialized |
| 134 | + |
| 135 | + |
| 136 | +def flashinfer_allreduce_residual_rmsnorm( |
| 137 | + fd_config: FDConfig, |
| 138 | + input_tensor: paddle.Tensor, |
| 139 | + residual: paddle.Tensor, |
| 140 | + weight: paddle.Tensor, |
| 141 | + eps: float = 1e-6, |
| 142 | + max_token_num: int = 2048, |
| 143 | + use_oneshot: Optional[bool] = None, |
| 144 | + trigger_completion_at_end: bool = False, |
| 145 | + fp32_acc: bool = False, |
| 146 | +) -> Tuple[paddle.Tensor, paddle.Tensor]: |
| 147 | + """ |
| 148 | + Use FlashInfer's fused allreduce + residual + RMS norm operation |
| 149 | + """ |
| 150 | + comm = _get_flashinfer_comm() |
| 151 | + if not has_flashinfer() or comm is None: |
| 152 | + logger.debug("FlashInfer not available, falling back to standard " "implementation") |
| 153 | + return None, None |
| 154 | + |
| 155 | + assert fd_config is not None |
| 156 | + world_size = fd_config.parallel_config.tensor_parallel_size |
| 157 | + if world_size <= 1: |
| 158 | + logger.debug("Single GPU, no need for allreduce fusion") |
| 159 | + return None, None |
| 160 | + |
| 161 | + assert input_tensor.shape[0] <= max_token_num |
| 162 | + |
| 163 | + if not ensure_workspace_initialized( |
| 164 | + fd_config=fd_config, |
| 165 | + max_token_num=max_token_num, |
| 166 | + hidden_dim=input_tensor.shape[-1], |
| 167 | + use_fp32_lamport=(input_tensor.dtype == paddle.float32), |
| 168 | + ): |
| 169 | + logger.debug("FlashInfer workspace not available") |
| 170 | + return None, None |
| 171 | + |
| 172 | + token_num, hidden_dim = input_tensor.shape |
| 173 | + |
| 174 | + residual_out = paddle.empty_like(residual) |
| 175 | + norm_out = paddle.empty_like(input_tensor) |
| 176 | + # support empty tensor |
| 177 | + if input_tensor.shape[0] == 0: |
| 178 | + return norm_out, residual_out |
| 179 | + comm.trtllm_allreduce_fusion( |
| 180 | + allreduce_in=input_tensor, |
| 181 | + world_size=world_size, |
| 182 | + world_rank=dist.get_rank(), |
| 183 | + token_num=token_num, |
| 184 | + hidden_dim=hidden_dim, |
| 185 | + workspace_ptrs=_workspace_manager.workspace_tensor, |
| 186 | + launch_with_pdl=True, |
| 187 | + use_oneshot=use_oneshot, |
| 188 | + trigger_completion_at_end=trigger_completion_at_end, |
| 189 | + fp32_acc=fp32_acc, |
| 190 | + pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm), |
| 191 | + allreduce_out=None, |
| 192 | + residual_in=residual, |
| 193 | + residual_out=residual_out, |
| 194 | + norm_out=norm_out, |
| 195 | + quant_out=None, |
| 196 | + scale_out=None, |
| 197 | + rms_gamma=weight, |
| 198 | + rms_eps=eps, |
| 199 | + scale_factor=None, |
| 200 | + layout_code=None, |
| 201 | + ) |
| 202 | + |
| 203 | + return norm_out, residual_out |
| 204 | + |
| 205 | + |
| 206 | +def cleanup_flashinfer_workspace(): |
| 207 | + global _workspace_manager |
| 208 | + if _workspace_manager is not None: |
| 209 | + _workspace_manager.cleanup() |
0 commit comments