|
| 1 | +# =================================================================== |
| 2 | +# 原始 learn.py 代码(预训练 Dataset 类方法版本,已注释) |
| 3 | +# =================================================================== |
| 4 | + |
def sample_independent_mtp_hidden_inputs_mask_(self, mtp_hidden_inputs_mask_tmp, mtp_ids_input_ids, eos_token_id):
    """Zero the MTP hidden-inputs mask wherever the (shifted) MTP input
    sequence carries an EOS token, so hidden states are not fused across
    document boundaries.

    Mutates ``mtp_hidden_inputs_mask_tmp`` (shape [1, L]) in place and
    returns it.
    """
    # Integer positions of EOS tokens in the shifted MTP input sequence;
    # mtp_ids_input_ids may be shorter than the mask row, so index by
    # position rather than a boolean mask.
    eos_idx = np.flatnonzero(mtp_ids_input_ids == eos_token_id)
    mtp_hidden_inputs_mask_tmp[0, eos_idx] = 0
    return mtp_hidden_inputs_mask_tmp
| 11 | + |
| 12 | +# 数据流在mtp情况下,新增两个输入: |
| 13 | +# "mtp_startend_row_indices_all", |
| 14 | +# "mtp_hidden_inputs_mask_all", |
| 15 | + |
def get_mtp_inputs_info(self, task_id, random_doc, ids):
    """Build the MTP auxiliary inputs: mtp_startend_row_indices_all and
    mtp_hidden_inputs_mask_all, one entry per MTP prediction layer.

    Args:
        task_id: task identifier; document isolation is only applied for
            task 0 (presumably the text-pretraining task — confirm with caller).
        random_doc: random draw compared against self.document_mask_prob_text
            to decide whether this sample gets document isolation.
        ids: full token-id sequence; assumed length self.seqlen — TODO confirm.

    Returns:
        Tuple (mtp_startend_row_indices_all, mtp_hidden_inputs_mask_all) of
        shapes [depth, seqlen-1, 2] and [depth, seqlen-1], or (None, None)
        when self.multi_token_pred_depth == 0.
    """

    mtp_startend_row_indices_all = []
    mtp_hidden_inputs_mask_all = []
    ids_input = ids[:-1]  # drop the last token to align with the network input
    # !! Constructed the same way as ernie5_moe.modeling_pp.MTPmmLayerPipe.forward:
    # when shorter than max_seq_len, mtp_input_id is taken from the original input_ids.
    ids_mtp = ids_input[-self.multi_token_pred_depth :]
    ids_ori = ids_input[: -self.multi_token_pred_depth]

    for mtp_idx in range(self.multi_token_pred_depth):
        # Build mtp_startend_row_indices: column 0 is filled with seqlen-1
        # (global-causal end row), column 1 is the position index.
        mtp_startend_row_indices_tmp = np.expand_dims(
            np.stack(
                [
                    np.full((self.seqlen - 1,), self.seqlen - 1, dtype=np.int32),
                    np.arange(self.seqlen - 1, dtype=np.int32),
                ],
                axis=1,
            ),
            0,
        )  # [1, seqlen-1, 2]
        # Build mtp_hidden_inputs_mask; all ones = keep every hidden state.
        mtp_hidden_inputs_mask_tmp = np.ones([1, self.seqlen - 1], dtype=np.int32)  # [1, seqlen-1]

        # Input sequence of MTP layer mtp_idx: original ids shifted left by
        # (mtp_idx + 1), tail filled from ids_mtp.
        mtp_ids_input_ids = np.concatenate([ids_ori[(mtp_idx + 1) :], ids_mtp[: (mtp_idx + 1)]])

        if (mtp_ids_input_ids.shape[0] > 0 and self.document_mask_prob_text is not None
                and task_id in [0] and random_doc < self.document_mask_prob_text):
            # Documents are mutually invisible (probabilistic isolation path).
            assert self.inbatch_sft is False, "document_mask_prob_text not support inbatch_sft"
            sample_independent_startend_row_indices_(
                mtp_startend_row_indices_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
            mtp_hidden_inputs_mask_tmp = self.sample_independent_mtp_hidden_inputs_mask_(
                mtp_hidden_inputs_mask_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
        elif mtp_ids_input_ids.shape[0] > 0 and task_id in [0] and self.inbatch_sft:
            # Documents are mutually invisible (in-batch SFT path).
            assert self.document_mask_prob_text is None, f"{self.document_mask_prob_text}"
            sample_independent_startend_row_indices_(
                mtp_startend_row_indices_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )
            mtp_hidden_inputs_mask_tmp = self.sample_independent_mtp_hidden_inputs_mask_(
                mtp_hidden_inputs_mask_tmp,
                mtp_ids_input_ids[:-1] if len(mtp_ids_input_ids) == self.seqlen else mtp_ids_input_ids,
                self.eos_token_id,
            )

        mtp_startend_row_indices_all.append(deepcopy(mtp_startend_row_indices_tmp))
        mtp_hidden_inputs_mask_all.append(deepcopy(mtp_hidden_inputs_mask_tmp))
    if len(mtp_startend_row_indices_all) > 0:
        mtp_startend_row_indices_all = np.concatenate(mtp_startend_row_indices_all, axis=0)
        mtp_hidden_inputs_mask_all = np.concatenate(mtp_hidden_inputs_mask_all, axis=0)
    else:
        # Depth 0: no MTP layers, signal "no auxiliary inputs" with None.
        mtp_startend_row_indices_all = None
        mtp_hidden_inputs_mask_all = None

    return mtp_startend_row_indices_all, mtp_hidden_inputs_mask_all
| 81 | + |
| 82 | +# 得到的"mtp_startend_row_indices_all", "mtp_hidden_inputs_mask_all" 也需要进行padding: |
def pad_mtp_hidden_inputs_mask(mtp_hidden_inputs_mask_all, max_seq_len):
    """Right-pad the MTP hidden-inputs mask with ones up to ``max_seq_len``.

    Args:
        mtp_hidden_inputs_mask_all: array of shape [num_heads, length];
            num_heads must be 1 (pretraining only covered mtp depth == 1).
        max_seq_len: target sequence length after padding.

    Returns:
        Array of shape [num_heads, max_seq_len]. The input is returned
        unchanged (not copied) when it is already at full length.

    Raises:
        ValueError: if the mask is longer than ``max_seq_len``.
    """
    num_heads, length = mtp_hidden_inputs_mask_all.shape
    assert num_heads == 1, "预训练只训练了mtp=1的情况"
    if length == max_seq_len:
        return mtp_hidden_inputs_mask_all
    if length < max_seq_len:
        # Padding positions are kept at 1 (hidden state retained), matching
        # the padding convention of gen_mtp_layer_mask.
        pad = np.ones((num_heads, max_seq_len - length), dtype=mtp_hidden_inputs_mask_all.dtype)
        return np.concatenate([mtp_hidden_inputs_mask_all, pad], axis=1)
    # length > max_seq_len cannot be repaired by padding; fail loudly with
    # context instead of a bare Exception.
    raise ValueError(
        f"mtp_hidden_inputs_mask length {length} exceeds max_seq_len {max_seq_len}"
    )
| 95 | + |
| 96 | + |
| 97 | +# =================================================================== |
| 98 | +# collate.py 对应翻译版(独立函数,无 self,对齐 collate.py 风格) |
| 99 | +# =================================================================== |
| 100 | +# |
| 101 | +# 变量对应关系: |
| 102 | +# learn.py (self.xxx) collate.py 参数 |
| 103 | +# --------------------------------------------------- |
| 104 | +# self.multi_token_pred_depth <-> mtp_depth |
| 105 | +# self.seqlen - 1 <-> total_len(实际 token 数,padding 前) |
| 106 | +# max_seq_len(含 MTP 扩展) <-> max_seq_len(已 += mtp_depth) |
| 107 | +# self.eos_token_id <-> eos_token_id |
| 108 | +# self.inbatch_sft / document <-> not use_global_causal_attn |
| 109 | +# ids(单条序列) <-> batch_token_ids(多条打包序列的列表) |
| 110 | +# |
| 111 | +# 关键差异: |
| 112 | +# learn.py 通过 EOS token 检测文档边界; |
| 113 | +# collate.py 的 packing 中文档边界已由 batch_token_ids 的分组隐式给出, |
| 114 | +# use_global_causal_attn=False 时直接按分组做分块 causal mask, |
| 115 | +# 无需再扫描 EOS。gen_mtp_layer_mask 仍保留 eos_token_id 参数以兼容 |
| 116 | +# 有 EOS 显式标记的场景。 |
| 117 | +# =================================================================== |
| 118 | + |
from typing import List, Optional

import numpy as np
| 121 | + |
| 122 | + |
def gen_mtp_attn_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Build the per-layer MTP attention masks as dense matrices.

    Layer k of MTP consumes the packed sequence shifted by (k + 1) steps,
    so every packed-document boundary moves left by (k + 1) in that layer's
    view; each layer therefore gets its own block-causal mask.

    Mirrors the mtp_startend_row_indices construction in the original
    learn.py get_mtp_inputs_info: start from a global causal mask and, when
    document isolation is on, re-split the blocks at the shifted boundaries.

    Args:
        batch_token_ids: packed sequences; document boundaries are implied
            by the grouping, so no EOS scan is needed.
        max_seq_len: padded sequence length (already extended by mtp_depth).
        mtp_depth: number of MTP prediction layers D.
        use_global_causal_attn: True -> one global causal block;
            False -> block-causal with per-layer shifted boundaries.

    Returns:
        float32 array of shape [mtp_depth, 1, max_seq_len, max_seq_len],
        matching gen_self_attn_mask's [1, 1, S, S] layout with dim 0
        expanded to mtp_depth.
    """
    n_tokens = sum(len(seq) for seq in batch_token_ids)

    # Exclusive internal document boundaries; the trailing n_tokens boundary
    # is appended per layer below.
    doc_ends = []
    pos = 0
    for seq in batch_token_ids[:-1]:
        pos += len(seq)
        doc_ends.append(pos)

    layers = np.zeros((mtp_depth, max_seq_len, max_seq_len), dtype=np.float32)
    for layer_idx, layer_mask in enumerate(layers):
        if use_global_causal_attn:
            # One global causal block covering the real tokens.
            layer_mask[:n_tokens, :n_tokens] = np.tril(np.ones([n_tokens, n_tokens]))
            continue

        # Boundary originally at b shows up at b - (layer_idx + 1) in the
        # shifted view of this MTP layer; drop boundaries shifted out of range.
        step = layer_idx + 1
        bounds = [end - step for end in doc_ends if end - step > 0]
        bounds.append(n_tokens)  # last block always runs to the sequence end

        start = 0
        for end in bounds:
            width = end - start
            if width > 0:
                # Lower-triangular causal block on [start:end, start:end].
                layer_mask[start:end, start:end] = np.tril(np.ones([width, width]))
            start = end

    # [mtp_depth, S, S] -> [mtp_depth, 1, S, S]
    return layers[:, None, :, :]
| 191 | + |
| 192 | + |
def gen_mtp_attn_mask_startend_row_indices(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Build the per-layer MTP attention masks in compressed
    startend_row_indices form.

    Compressed counterpart of gen_mtp_attn_mask, format-compatible with
    gen_attn_mask_startend_row_indices: each position stores the exclusive
    end row of its block, from which the flash-attention kernel derives the
    causal structure. Block boundaries for layer k are the original
    boundaries minus (k + 1).

    Args:
        Same as gen_mtp_attn_mask.

    Returns:
        int32 array of shape [mtp_depth, 1, max_seq_len, 1], matching
        gen_attn_mask_startend_row_indices' [1, 1, S, 1] layout with dim 0
        expanded to mtp_depth.
    """
    n_tokens = sum(len(seq) for seq in batch_token_ids)

    # Exclusive internal document boundaries (grouping implies boundaries).
    doc_ends = []
    pos = 0
    for seq in batch_token_ids[:-1]:
        pos += len(seq)
        doc_ends.append(pos)

    per_layer = []
    for layer_idx in range(mtp_depth):
        if use_global_causal_attn:
            # Single block: every real position ends at n_tokens.
            rows = [n_tokens] * n_tokens
        else:
            step = layer_idx + 1
            bounds = [end - step for end in doc_ends if end - step > 0]
            bounds.append(n_tokens)

            rows = []
            start = 0
            for end in bounds:
                # All positions inside a block share the block's exclusive end.
                rows.extend([end] * (end - start))
                start = end

        # Padding region: strictly increasing end rows, consistent with
        # gen_attn_mask_startend_row_indices (empty when already full length).
        rows.extend(range(n_tokens, max_seq_len))
        per_layer.append(rows)

    # [mtp_depth, S] -> [mtp_depth, 1, S, 1]
    return np.array(per_layer, dtype=np.int32)[:, None, :, None]
| 249 | + |
| 250 | + |
def gen_mtp_layer_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    eos_token_id: Optional[int] = None,
) -> np.ndarray:
    """Build the per-layer MTP hidden-inputs mask (pending further
    discussion; original implementation retained).

    Layer k's input is the packed sequence shifted by (k + 1) steps;
    positions holding an EOS token in that shifted view are zeroed so
    hidden states are not fused across document boundaries.

    Args:
        batch_token_ids: packed sequences of token ids.
        max_seq_len: padded sequence length (already extended by mtp_depth).
        mtp_depth: number of MTP prediction layers.
        eos_token_id: EOS id used to locate explicit document boundaries;
            when None (the packing case, boundaries implicit in grouping),
            no positions are zeroed.

    Returns:
        int32 array of shape [mtp_depth, max_seq_len]; 1 keeps the hidden
        state, 0 drops it. Padding positions are kept at 1.
    """
    all_token_ids = np.concatenate([np.array(ids, dtype=np.int32) for ids in batch_token_ids])
    total_len = len(all_token_ids)
    # Tail tokens recycled into the shifted layer inputs, mirroring the
    # ids_mtp / ids_ori split in get_mtp_inputs_info.
    ids_mtp = all_token_ids[-mtp_depth:]
    ids_ori = all_token_ids[:-mtp_depth]

    result = []
    for mtp_idx in range(mtp_depth):
        mask = np.ones(total_len, dtype=np.int32)
        if eos_token_id is not None:
            # Layer input shifted by (mtp_idx + 1), tail filled from ids_mtp.
            mtp_ids_input_ids = np.concatenate([ids_ori[mtp_idx + 1:], ids_mtp[:mtp_idx + 1]])
            eos_positions = np.where(mtp_ids_input_ids == eos_token_id)[0]
            mask[eos_positions] = 0
        if total_len < max_seq_len:
            # Padding stays 1, consistent with pad_mtp_hidden_inputs_mask.
            mask = np.concatenate([mask, np.ones(max_seq_len - total_len, dtype=np.int32)])
        result.append(mask)

    return np.stack(result, axis=0)  # [mtp_depth, max_seq_len]
0 commit comments