Skip to content

Commit c04ba85

Browse files
authored
add mtp attn mask and layer mask (#4343)
1 parent 7a3207a commit c04ba85

2 files changed

Lines changed: 315 additions & 95 deletions

File tree

paddleformers/datasets/collate.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,11 @@ def collate_fn(
482482
max_seq_len = calc_padding_size(max_seq_len, training_args)
483483
if training_args.num_nextn_predict_layers > 0:
484484
max_seq_len += training_args.num_nextn_predict_layers
485+
if model_args.use_attn_mask_startend_row_indices:
486+
input_keys.append("mtp_attn_mask_startend_row_indices")
487+
else:
488+
input_keys.append("mtp_attn_mask")
489+
input_keys.append("mtp_layer_mask")
485490

486491
for batch_sequence in batch:
487492
if len(batch_sequence) == 1 and isinstance(batch_sequence[0].position_ids[0], List):
@@ -525,6 +530,36 @@ def collate_fn(
525530
gen_self_attn_mask(original_position_ids, max_seq_len, model_args.use_global_causal_attn)
526531
)
527532

533+
if training_args.num_nextn_predict_layers > 0:
534+
535+
if model_args.use_attn_mask_startend_row_indices:
536+
return_list[-1].append(
537+
gen_mtp_attn_mask_startend_row_indices(
538+
original_position_ids,
539+
max_seq_len,
540+
training_args.num_nextn_predict_layers,
541+
model_args.use_global_causal_attn,
542+
)
543+
)
544+
else:
545+
return_list[-1].append(
546+
gen_mtp_attn_mask(
547+
original_position_ids,
548+
max_seq_len,
549+
training_args.num_nextn_predict_layers,
550+
model_args.use_global_causal_attn,
551+
)
552+
)
553+
554+
return_list[-1].append(
555+
gen_mtp_layer_mask(
556+
original_position_ids,
557+
max_seq_len,
558+
training_args.num_nextn_predict_layers,
559+
tokenizer.eos_token_id,
560+
)
561+
)
562+
528563
return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
529564
input_dict = dict(zip(input_keys, return_list))
530565
return input_dict
@@ -851,3 +886,118 @@ def gen_attn_mask_startend_row_indices(
851886
attn_mask_startend_row_indices.extend(list(range(offset, max_seq_len)))
852887
# NOTE(hehuang): The dtype of attn_mask_startend_row_indices must be np.int32
853888
return np.array(attn_mask_startend_row_indices, dtype=np.int32)[None, None, ..., None] # add dimension modify
889+
890+
891+
def gen_mtp_attn_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Generate MTP per-layer attention mask (dense 2D matrix form).

    Args:
        batch_token_ids: List of token ID sequences (document grouping provides boundaries).
        max_seq_len: Padded sequence length, already extended by mtp_depth.
        mtp_depth: Number of MTP prediction layers D.
        use_global_causal_attn: If True, use a single global causal block shared by
            every layer; otherwise build block-causal masks whose document
            boundaries are shifted left by the layer depth.

    Returns:
        np.ndarray, shape [mtp_depth, 1, max_seq_len, max_seq_len], dtype=float32.
    """
    valid_len = sum(len(seq) for seq in batch_token_ids)
    if use_global_causal_attn:
        # One lower-triangular block over all valid tokens; padding stays zero.
        base = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)
        base[:valid_len, :valid_len] = np.tril(np.ones([valid_len, valid_len]))
        stacked = np.stack([base] * mtp_depth, axis=0)
    else:
        # Cumulative end offsets of every document except the last one.
        doc_ends = []
        running = 0
        for seq in batch_token_ids[:-1]:
            running += len(seq)
            doc_ends.append(running)
        layers = []
        for depth in range(1, mtp_depth + 1):
            layer = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)
            # Shift each internal boundary left by the layer depth; drop any
            # boundary that falls off the front, then close with valid_len.
            edges = [end - depth for end in doc_ends if end - depth > 0]
            edges.append(valid_len)
            start = 0
            for end in edges:
                if end > start:
                    span = end - start
                    layer[start:end, start:end] = np.tril(np.ones([span, span]))
                start = end
            layers.append(layer)
        stacked = np.stack(layers, axis=0)
    return stacked[:, None, :, :]
933+
934+
935+
def gen_mtp_attn_mask_startend_row_indices(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    use_global_causal_attn: bool,
) -> np.ndarray:
    """Generate MTP per-layer attention mask (compressed startend_row_indices form).

    Each position stores the exclusive end row of the causal block it belongs
    to; padding positions point at themselves.

    Args:
        batch_token_ids: List of token ID sequences.
        max_seq_len: Padded sequence length, already extended by mtp_depth.
        mtp_depth: Number of MTP prediction layers D.
        use_global_causal_attn: If True, single global block; otherwise per-layer shifted blocks.

    Returns:
        np.ndarray, shape [mtp_depth, 1, max_seq_len, 1], dtype=int32.
    """
    valid_len = sum(len(seq) for seq in batch_token_ids)
    # Padding positions index themselves so they attend to nothing.
    tail = list(range(valid_len, max_seq_len))
    if use_global_causal_attn:
        rows = np.array([[valid_len] * valid_len + tail] * mtp_depth, dtype=np.int32)
    else:
        # Cumulative end offsets of every document except the last one.
        doc_ends = []
        running = 0
        for seq in batch_token_ids[:-1]:
            running += len(seq)
            doc_ends.append(running)
        per_layer = []
        for depth in range(1, mtp_depth + 1):
            # Boundaries shifted left by the layer depth, then closed at valid_len.
            edges = [end - depth for end in doc_ends if end - depth > 0] + [valid_len]
            row = []
            start = 0
            for end in edges:
                row += [end] * (end - start)
                start = end
            per_layer.append(row + tail)
        rows = np.array(per_layer, dtype=np.int32)
    # NOTE: downstream expects dtype int32 for startend_row_indices.
    return rows[:, None, :, None]
975+
976+
977+
def gen_mtp_layer_mask(
    batch_token_ids: List[List[int]],
    max_seq_len: int,
    mtp_depth: int,
    eos_token_id: int = None,
) -> np.ndarray:
    """Generate MTP per-layer hidden inputs mask.

    For MTP layer ``d`` (0-based), position ``i`` is zeroed when the token
    ``d + 1`` steps ahead equals ``eos_token_id``; all other positions are 1.

    Args:
        batch_token_ids: List of token ID sequences.
        max_seq_len: Padded sequence length, already extended by mtp_depth.
        mtp_depth: Number of MTP prediction layers D.
        eos_token_id: If provided, zero out positions where EOS appears in the
            shifted input; ``None`` disables EOS masking entirely.

    Returns:
        np.ndarray, shape [mtp_depth, max_seq_len], dtype=int32.
    """
    # No EOS masking requested, or nothing to scan — note np.concatenate raises
    # ValueError on an empty list, so the empty batch must short-circuit here.
    if eos_token_id is None or not batch_token_ids:
        return np.ones((mtp_depth, max_seq_len), dtype=np.int32)
    all_token_ids = np.concatenate([np.array(ids, dtype=np.int32) for ids in batch_token_ids])
    result = []
    for mtp_idx in range(mtp_depth):
        mask = np.ones(max_seq_len, dtype=np.int32)
        # Position i of the mask aligns with token i + mtp_idx + 1 of the packed
        # sequence; zero the positions whose shifted token is EOS.
        shifted = all_token_ids[mtp_idx + 1 :]
        mask[np.where(shifted == eos_token_id)[0]] = 0
        result.append(mask)
    return np.stack(result, axis=0)

0 commit comments

Comments
 (0)