Skip to content

Commit 8de7924

Browse files
committed
add mtp attn mask and layer mask
1 parent 86ec329 commit 8de7924

8 files changed

Lines changed: 788 additions & 98 deletions

File tree

download.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
import os
import time
from aistudio_sdk.snapshot_download import snapshot_download

# SECURITY NOTE(review): a real-looking access token is hardcoded here and is now
# committed to version control — it should be rotated and read from the
# environment instead of being assigned in source.
os.environ["AISTUDIO_ACCESS_TOKEN"] = "2803c8bd3c444e6bacc7286c84acee55bc2d4cd5"
# Route outbound HTTP(S) traffic through the internal proxy.
os.environ["http_proxy"] = "http://agent.baidu.com:8188"
os.environ["https_proxy"] = "http://agent.baidu.com:8188"

# Model repo ids to download.
MODEL_LIST = [
    "PaddleFormers/tiny-random-glm4moe-bf16",
]

# Base directory under which each model snapshot is stored.
BASE_DIR = "/home/models/"

# Keyword arguments shared by every snapshot_download call.
COMMON_ARGS = {
    "revision": "master",
}
21+
def download_model(repo_id: str, max_retries: int = 3) -> bool:
    """Download *repo_id* into BASE_DIR, retrying up to *max_retries* times.

    Args:
        repo_id: Model repository id, e.g. "Org/model-name".
        max_retries: Total number of attempts before giving up.

    Returns:
        True when the snapshot download succeeds, False after all retries fail.
    """
    local_dir = os.path.join(BASE_DIR, repo_id)
    os.makedirs(local_dir, exist_ok=True)

    for attempt in range(1, max_retries + 1):
        try:
            snapshot_download(repo_id=repo_id, local_dir=local_dir, **COMMON_ARGS)
            return True
        except Exception as e:  # best-effort: report the failure and retry
            if attempt < max_retries:
                print(f" Retry {attempt}/{max_retries} for {repo_id}: {e}")
                time.sleep(5)
            else:
                print(f" Skip: {repo_id} after {max_retries} retries")
    return False
35+
36+
37+
if __name__ == "__main__":
    # Driver: fetch every configured model, announcing start and end of each.
    for repo in MODEL_LIST:
        print(repo, "download start....")
        download_model(repo)
        print(repo, "download end !")

estimation_output.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"num_train_epochs": 1, "max_steps": 10, "train_tokens": 707, "global_batch_size": 1, "gradient_accumulation_steps": 1, "warmup_steps": 1, "per_device_train_batch_size": 1, "tensor_model_parallel_size": -1, "pipeline_model_parallel_size": -1, "sharding_parallel_size": -1, "seed": 23, "num_samples_each_epoch": 6000000, "max_seq_len": 8192, "valid": true, "train_samples": 10, "estimate_samples": 10, "actual_train_samples": 10, "skip_samples": 0, "num_of_gpus": -1}

examples/config/sft/full.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ template_backend: custom
1212
template: qwen3
1313

1414
### model
15-
model_name_or_path: Qwen/Qwen3-0.6B-Base
15+
model_name_or_path: /home/models/PaddleFormers/tiny-random-glm4moe-bf16/
1616
_attn_implementation: flashmask
1717

1818
### finetuning
@@ -31,7 +31,7 @@ evaluation_strategy: steps
3131
save_steps: 100
3232
save_strategy: steps
3333
logging_steps: 1
34-
gradient_accumulation_steps: 4
34+
gradient_accumulation_steps: 1
3535
logging_dir: ./vdl_log
3636
output_dir: ./checkpoints/qwen3-sft-full
3737
disable_tqdm: true

learn.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# ===================================================================
2+
# 原始 learn.py 代码(预训练 Dataset 类方法版本,已注释)
3+
# ===================================================================
4+
5+
def sample_independent_mtp_hidden_inputs_mask_(self, mtp_hidden_inputs_mask_tmp, mtp_ids_input_ids, eos_token_id):
    """Zero the hidden-inputs mask (in place) at every EOS position of the shifted MTP sequence."""
    is_eos = mtp_ids_input_ids == eos_token_id
    mtp_hidden_inputs_mask_tmp[0][np.where(is_eos)] = 0
    return mtp_hidden_inputs_mask_tmp
11+
12+
# 数据流在mtp情况下,新增两个输入:
13+
# "mtp_startend_row_indices_all",
14+
# "mtp_hidden_inputs_mask_all",
15+
16+
def get_mtp_inputs_info(self, task_id, random_doc, ids):
    """Build the MTP inputs: (mtp_startend_row_indices_all, mtp_hidden_inputs_mask_all).

    For each of the `self.multi_token_pred_depth` MTP layers, the layer input is
    the original input sequence shifted left by (mtp_idx + 1), with the first
    (mtp_idx + 1) reserved tail tokens appended.  Per layer this produces:
      - startend row indices of shape [1, seqlen-1, 2];
      - a hidden-inputs mask of shape [1, seqlen-1].

    Args:
        task_id: task identifier; document masking only applies when task_id is 0.
        random_doc: random draw compared against self.document_mask_prob_text.
        ids: full token-id sequence; ids[:-1] is used to align with the network input.

    Returns:
        Tuple of np.ndarrays stacked along the depth dimension, or (None, None)
        when multi_token_pred_depth is 0.
    """

    mtp_startend_row_indices_all = []
    mtp_hidden_inputs_mask_all = []
    ids_input = ids[:-1]  # align with the network input (drop the final label token)
    # NOTE: built the same way as ernie5_moe.modeling_pp.MTPmmLayerPipe.forward —
    # when shorter than max_seq_len, mtp_input_id is taken from the original input_ids.
    ids_mtp = ids_input[-self.multi_token_pred_depth :]
    ids_ori = ids_input[: -self.multi_token_pred_depth]

    for mtp_idx in range(self.multi_token_pred_depth):
        # Build mtp_startend_row_indices: initialized as one global causal block
        # (end_row = seqlen-1 for every position, start_row = position index).
        mtp_startend_row_indices_tmp = np.expand_dims(
            np.stack(
                [
                    np.full((self.seqlen - 1,), self.seqlen - 1, dtype=np.int32),
                    np.arange(self.seqlen - 1, dtype=np.int32),
                ],
                axis=1,
            ),
            0,
        )  # [1, seqlen-1, 2]
        # Build mtp_hidden_inputs_mask: all positions valid by default.
        mtp_hidden_inputs_mask_tmp = np.ones([1, self.seqlen - 1], dtype=np.int32)  # [1, seqlen-1]

        # Layer input: original ids shifted left by (mtp_idx + 1) plus reserved tail tokens.
        mtp_ids_input_ids = np.concatenate([ids_ori[(mtp_idx + 1) :], ids_mtp[: (mtp_idx + 1)]])

        if (mtp_ids_input_ids.shape[0] > 0 and self.document_mask_prob_text is not None
            and task_id in [0] and random_doc < self.document_mask_prob_text):
            # Documents must not attend to each other (sample isolation).
            assert self.inbatch_sft is False, "document_mask_prob_text not support inbatch_sft"
            sample_independent_startend_row_indices_(
                mtp_startend_row_indices_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
            mtp_hidden_inputs_mask_tmp = self.sample_independent_mtp_hidden_inputs_mask_(
                mtp_hidden_inputs_mask_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
        elif mtp_ids_input_ids.shape[0] > 0 and task_id in [0] and self.inbatch_sft:
            # In-batch SFT: samples must not attend to each other either.
            assert self.document_mask_prob_text is None, f"{self.document_mask_prob_text}"
            sample_independent_startend_row_indices_(
                mtp_startend_row_indices_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
            mtp_hidden_inputs_mask_tmp = self.sample_independent_mtp_hidden_inputs_mask_(
                mtp_hidden_inputs_mask_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )

        # deepcopy so later in-place edits cannot alias earlier layers.
        mtp_startend_row_indices_all.append(deepcopy(mtp_startend_row_indices_tmp))
        mtp_hidden_inputs_mask_all.append(deepcopy(mtp_hidden_inputs_mask_tmp))
    if len(mtp_startend_row_indices_all) > 0:
        mtp_startend_row_indices_all = np.concatenate(mtp_startend_row_indices_all, axis=0)
        mtp_hidden_inputs_mask_all = np.concatenate(mtp_hidden_inputs_mask_all, axis=0)
    else:
        mtp_startend_row_indices_all = None
        mtp_hidden_inputs_mask_all = None

    return mtp_startend_row_indices_all, mtp_hidden_inputs_mask_all
81+
82+
# 得到的"mtp_startend_row_indices_all", "mtp_hidden_inputs_mask_all" 也需要进行padding:
83+
def pad_mtp_hidden_inputs_mask(mtp_hidden_inputs_mask_all, max_seq_len):
    """Right-pad the MTP hidden-inputs mask to *max_seq_len* along its last axis.

    Args:
        mtp_hidden_inputs_mask_all: int array of shape [1, l]; only mtp depth 1
            was used in pretraining, hence the head == 1 assertion.
        max_seq_len: target length; must satisfy l <= max_seq_len.

    Returns:
        Array of shape [1, max_seq_len]; padding positions are filled with 1.

    Raises:
        ValueError: if the mask is already longer than max_seq_len.
    """
    head, l = mtp_hidden_inputs_mask_all.shape
    assert head == 1, "预训练只训练了mtp=1的情况"
    if l == max_seq_len:
        return mtp_hidden_inputs_mask_all
    if l < max_seq_len:
        pad_l = max_seq_len - l
        # NOTE(review): padding with ones marks padded positions as *valid* —
        # presumably downstream attention masking covers them; confirm.
        padding = np.ones((head, pad_l), dtype=mtp_hidden_inputs_mask_all.dtype)
        return np.concatenate([mtp_hidden_inputs_mask_all, padding], axis=1)
    raise ValueError(
        f"mtp_hidden_inputs_mask length {l} exceeds max_seq_len {max_seq_len}"
    )
95+
96+
97+
# ===================================================================
98+
# collate.py 对应翻译版(独立函数,无 self,对齐 collate.py 风格)
99+
# ===================================================================
100+
#
101+
# 变量对应关系:
102+
# learn.py (self.xxx) collate.py 参数
103+
# ---------------------------------------------------
104+
# self.multi_token_pred_depth <-> mtp_depth
105+
# self.seqlen - 1 <-> total_len(实际 token 数,padding 前)
106+
# max_seq_len(含 MTP 扩展) <-> max_seq_len(已 += mtp_depth)
107+
# self.eos_token_id <-> eos_token_id
108+
# self.inbatch_sft / document <-> not use_global_causal_attn
109+
# ids(单条序列) <-> batch_token_ids(多条打包序列的列表)
110+
#
111+
# 关键差异:
112+
# learn.py 通过 EOS token 检测文档边界;
113+
# collate.py 的 packing 中文档边界已由 batch_token_ids 的分组隐式给出,
114+
# use_global_causal_attn=False 时直接按分组做分块 causal mask,
115+
# 无需再扫描 EOS。gen_mtp_layer_mask 仍保留 eos_token_id 参数以兼容
116+
# 有 EOS 显式标记的场景。
117+
# ===================================================================
118+
119+
import numpy as np
120+
from typing import List
121+
122+
123+
def gen_mtp_attn_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Generate one dense attention mask per MTP layer.

    The input of MTP layer k is the packed sequence shifted right by (k + 1)
    steps, so every internal document boundary moves left by (k + 1) in that
    layer's view — each layer therefore gets its own block-causal mask.

    Args:
        batch_token_ids: packed sequences; document boundaries are implicit in
            the grouping (no EOS scan needed here).
        max_seq_len: padded sequence length (already includes the MTP extension).
        mtp_depth: number of MTP prediction layers.
        use_global_causal_attn: True → one global causal block;
            False → block-causal with boundaries shifted per layer.

    Returns:
        np.ndarray of shape [mtp_depth, 1, max_seq_len, max_seq_len], float32,
        matching gen_self_attn_mask's layout with dim 0 expanded to mtp_depth.
    """
    total_len = sum(map(len, batch_token_ids))

    # Internal document boundaries (exclusive); total_len itself is excluded.
    internal_boundaries = []
    pos = 0
    for ids in batch_token_ids[:-1]:
        pos += len(ids)
        internal_boundaries.append(pos)

    layers = []
    for depth in range(mtp_depth):
        layer = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)

        if use_global_causal_attn:
            # Whole sequence is one causal block.
            layer[:total_len, :total_len] = np.tril(np.ones([total_len, total_len]))
        else:
            # Boundaries move left by (depth + 1); drop any that fall off the front.
            shift = depth + 1
            bounds = [b - shift for b in internal_boundaries if b - shift > 0]
            # The final block always extends to the end of the packed sequence.
            bounds.append(total_len)

            start = 0
            for end in bounds:
                size = end - start
                if size > 0:
                    # Lower-triangular causal sub-mask within each block.
                    layer[start:end, start:end] = np.tril(np.ones([size, size]))
                start = end

        layers.append(layer)

    # [mtp_depth, max_seq_len, max_seq_len] -> [mtp_depth, 1, max_seq_len, max_seq_len]
    return np.stack(layers, axis=0)[:, None, :, :]
191+
192+
193+
def gen_mtp_attn_mask_startend_row_indices(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Generate one compressed (startend_row_indices) attention mask per MTP layer.

    Compressed counterpart of gen_mtp_attn_mask, aligned with
    gen_attn_mask_startend_row_indices: each position stores the exclusive
    end_row of its block, and the flash-attention kernel infers causality.
    Layer k's block boundaries equal the original boundaries minus (k + 1).

    Args:
        Same as gen_mtp_attn_mask.

    Returns:
        np.ndarray of shape [mtp_depth, 1, max_seq_len, 1], int32, matching
        gen_attn_mask_startend_row_indices with dim 0 expanded to mtp_depth.
    """
    total_len = sum(map(len, batch_token_ids))

    # Internal document boundaries (exclusive); total_len itself is excluded.
    internal_boundaries = []
    pos = 0
    for ids in batch_token_ids[:-1]:
        pos += len(ids)
        internal_boundaries.append(pos)

    rows = []
    for depth in range(mtp_depth):
        if use_global_causal_attn:
            # Single block: every position's end_row is the packed length.
            row = [total_len] * total_len
        else:
            shift = depth + 1
            bounds = [b - shift for b in internal_boundaries if b - shift > 0]
            bounds.append(total_len)

            row = []
            start = 0
            for end in bounds:
                # All positions inside a block share its exclusive end_row.
                row.extend([end] * (end - start))
                start = end

        # Padding region mirrors gen_attn_mask_startend_row_indices:
        # padding positions get strictly increasing end_rows.
        row.extend(range(total_len, max_seq_len))
        rows.append(row)

    # [mtp_depth, max_seq_len] -> [mtp_depth, 1, max_seq_len, 1]
    return np.array(rows, dtype=np.int32)[:, None, :, None]
249+
250+
251+
def gen_mtp_layer_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    eos_token_id: int = None,
) -> np.ndarray:
    """Generate the per-layer MTP hidden-inputs mask (kept as-is pending further discussion).

    For layer k, the packed tokens are shifted left by (k + 1) with the reserved
    tail tokens appended; positions holding EOS in that shifted view are zeroed.
    Padding positions are filled with 1.

    Returns:
        np.ndarray of shape [mtp_depth, max_seq_len], int32.
    """
    flat = np.concatenate([np.array(ids, dtype=np.int32) for ids in batch_token_ids])
    total_len = len(flat)
    tail = flat[-mtp_depth:]
    body = flat[:-mtp_depth]

    layers = []
    for depth in range(mtp_depth):
        layer_mask = np.ones(total_len, dtype=np.int32)
        if eos_token_id is not None:
            shifted = np.concatenate([body[depth + 1:], tail[:depth + 1]])
            # NOTE(review): `shifted` is mtp_depth shorter than `layer_mask`,
            # so the last mtp_depth positions are never zeroed here — confirm intended.
            layer_mask[np.where(shifted == eos_token_id)[0]] = 0
        if total_len < max_seq_len:
            layer_mask = np.concatenate(
                [layer_mask, np.ones(max_seq_len - total_len, dtype=np.int32)]
            )
        layers.append(layer_mask)

    return np.stack(layers, axis=0)  # [mtp_depth, max_seq_len]

paddleformers/cli/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"-" * 60
4444
+ "\n"
4545
+ "| Usage: |\n"
46-
+ "| paddleformers-cli train -h: model finetuning |\n"
46+
+ "| paddleformers-cli train -h: modeyl finetuning |\n"
4747
+ "| paddleformers-cli export -h: model export |\n"
4848
+ "| paddleformers-cli version: show version info |\n"
4949
+ "| paddleformers-cli help: show helping info |\n"

0 commit comments

Comments
 (0)