Skip to content

Commit 27b00cf

Browse files
authored
[XPU] glm-4.5-air (#7071)
1 parent 26c47c2 commit 27b00cf

9 files changed

Lines changed: 32 additions & 18 deletions

File tree

custom_ops/xpu_ops/download_dependencies.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ if [ "$1" == "stable" ]; then
1515
version_xvllm="20251017"
1616
version_xtdk="3.4.0.1"
1717
else
18-
version_xvllm="20260407"
18+
version_xvllm="latest"
1919
version_xtdk="3.6.2.1"
2020
fi
2121

custom_ops/xpu_ops/src/ops/block_attn.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
156156
rope_head_dim = rotary_embs.dims()[4];
157157
}
158158
std::string pos_emb_type;
159-
if (use_neox_rotary_style == true) {
159+
if (use_neox_rotary_style) {
160160
pos_emb_type = "NEOX";
161161
} else if (rope_head_dim == head_dim / 2) {
162162
pos_emb_type = "HALF_HEAD_DIM";
@@ -342,12 +342,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
342342
value_cache.data<cdata_t>())),
343343
vsl.usual_lod_vp, // seq_lod
344344
vsl.slot_mapping_vp, // real_batch
345+
prefix_lens_vp, // start_tokens
345346
param.batch_size, // batch_size
346347
1, // emb_batch_size
347348
rope_max_seqlen, // max_seqlen
348349
param.head_num,
349350
param.kv_head_num,
350351
param.head_dim,
352+
rope_head_dim,
351353
param.max_batch_size,
352354
block_size,
353355
max_block_per_seq,
@@ -586,7 +588,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
586588
ret = infer_ops::
587589
split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
588590
xpu_ctx->x_context(),
589-
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()), // qkv
591+
reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
592+
total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv
590593
reinterpret_cast<const float*>(
591594
rotary_embs.data<float>()), // rotary_pos_emb
592595
reinterpret_cast<const int*>(
@@ -598,14 +601,16 @@ std::vector<paddle::Tensor> BlockAttnKernel(
598601
key_cache.data<cdata_t>())),
599602
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
600603
value_cache.data<cdata_t>())),
601-
decoder_seq_lod_vp, // seq_lod
602-
decoder_batch_map_vp, // real_batch
603-
param.batch_size, // batch_size
604-
1, // emb_batch_size
605-
rope_max_seqlen, // max_seqlen
604+
decoder_seq_lod_vp, // seq_lod
605+
decoder_batch_map_vp, // real_batch
606+
decoder_context_len_cache_vp, // start_tokens
607+
param.batch_size, // batch_size
608+
1, // emb_batch_size
609+
rope_max_seqlen, // max_seqlen
606610
param.head_num,
607611
param.kv_head_num,
608612
param.head_dim,
613+
rope_head_dim,
609614
param.max_batch_size,
610615
block_size,
611616
max_block_per_seq,
@@ -806,6 +811,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
806811
param.head_num,
807812
param.kv_head_num,
808813
param.head_dim,
814+
rope_head_dim,
809815
param.max_batch_size,
810816
block_size,
811817
max_block_per_seq,

custom_ops/xpu_ops/src/ops/fused_noaux_tc.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,19 @@ std::vector<std::vector<int64_t>> FusedNoAuxTcInferShape(
7676
const float routed_scaling_factor) {
7777
std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], top_k};
7878
std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], top_k};
79-
return {gating_logits_shape, topk_ids_shape, topk_weights_shape};
79+
return {gating_logits_shape, topk_weights_shape, topk_ids_shape};
8080
}
8181

8282
std::vector<paddle::DataType> FusedNoAuxTcInferDtype(
8383
const paddle::DataType& gating_logits_dtype,
8484
const paddle::DataType& bias_dtype) {
8585
return {
86-
gating_logits_dtype, paddle::DataType::INT64, paddle::DataType::FLOAT32};
86+
gating_logits_dtype, paddle::DataType::FLOAT32, paddle::DataType::INT32};
8787
}
8888

8989
PD_BUILD_STATIC_OP(fused_noaux_tc)
9090
.Inputs({"gating_logits", "bias"})
91-
.Outputs({"gating_logits_out", "topk_ids", "topk_weights"})
91+
.Outputs({"gating_logits_out", "topk_weights", "topk_ids"})
9292
.Attrs({"n_group: int",
9393
"topk_group: int",
9494
"top_k: int",

fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def apply_tp(
313313
"""
314314
gate_out = gate(x.cast("float32"))
315315
if layer.topk_method == "noaux_tc":
316-
_, topk_idx, topk_weights = get_moe_scores(
316+
_, topk_weights, topk_idx = get_moe_scores(
317317
gate_out,
318318
layer.n_group,
319319
layer.topk_group,

fastdeploy/model_executor/layers/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
6161
)
6262

6363
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
64-
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
64+
if extra_weight_attrs["output_dim"] is not None:
65+
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
6566

6667
set_weight_attrs(
6768
layer.weight,

fastdeploy/model_executor/layers/quantization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
136136
logger.warning(f"Failed to parse quantization config normally ({e}), trying fallback")
137137
quant_config_name = args.quantization["quantization"]
138138
quantization_config["quantization"] = quant_config_name
139+
model_config.quantization_config = quantization_config
139140
# Special handling for Ernie models
140141
if quant_config_name == "wint4" and is_ernie:
141142
quantization_config["dense_quant_type"] = "wint8"

fastdeploy/model_executor/layers/rotary_embedding.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __call__(self, position_ids):
4444
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
4545
partial_rotary_position_ids = position_ids / self.partial_rotary_factor
4646
freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
47-
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
47+
if current_platform.is_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
4848
# shape: [B, S, D]
4949
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
5050
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
@@ -95,9 +95,14 @@ def __call__(self, position_ids):
9595
else:
9696
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
9797
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
98-
# shape: [B, S, D/2]
99-
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
100-
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
98+
if current_platform.is_xpu():
99+
# shape: [B, S, D]
100+
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
101+
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
102+
else:
103+
# shape: [B, S, D/2]
104+
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
105+
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
101106
# shape: [B, S, 1, D]
102107
emb = paddle.unsqueeze(emb, 2)
103108
rot_emb[0] = paddle.cos(emb)

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
fd_config=fd_config,
7474
prefix=f"{prefix}.up_gate_proj",
7575
input_size=fd_config.model_config.hidden_size,
76-
output_size=[intermediate_size, intermediate_size],
76+
output_sizes=[intermediate_size, intermediate_size],
7777
with_bias=False,
7878
)
7979

fastdeploy/worker/xpu_model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,7 @@ def _init_share_inputs(self, max_num_seqs: int):
990990
position_ids=tmp_position_ids,
991991
base=self.model_config.rope_theta,
992992
model_config=self.model_config,
993+
partial_rotary_factor=self.model_config.partial_rotary_factor,
993994
)
994995

995996
# Set block tables

0 commit comments

Comments
 (0)