@@ -482,6 +482,11 @@ def collate_fn(
     max_seq_len = calc_padding_size(max_seq_len, training_args)
     if training_args.num_nextn_predict_layers > 0:
         max_seq_len += training_args.num_nextn_predict_layers
+        if model_args.use_attn_mask_startend_row_indices:
+            input_keys.append("mtp_attn_mask_startend_row_indices")
+        else:
+            input_keys.append("mtp_attn_mask")
+        input_keys.append("mtp_layer_mask")
 
     for batch_sequence in batch:
         if len(batch_sequence) == 1 and isinstance(batch_sequence[0].position_ids[0], List):
@@ -525,6 +530,36 @@ def collate_fn(
                 gen_self_attn_mask(original_position_ids, max_seq_len, model_args.use_global_causal_attn)
             )
 
+            if training_args.num_nextn_predict_layers > 0:
+
+                if model_args.use_attn_mask_startend_row_indices:
+                    return_list[-1].append(
+                        gen_mtp_attn_mask_startend_row_indices(
+                            original_position_ids,
+                            max_seq_len,
+                            training_args.num_nextn_predict_layers,
+                            model_args.use_global_causal_attn,
+                        )
+                    )
+                else:
+                    return_list[-1].append(
+                        gen_mtp_attn_mask(
+                            original_position_ids,
+                            max_seq_len,
+                            training_args.num_nextn_predict_layers,
+                            model_args.use_global_causal_attn,
+                        )
+                    )
+
+                return_list[-1].append(
+                    gen_mtp_layer_mask(
+                        original_position_ids,
+                        max_seq_len,
+                        training_args.num_nextn_predict_layers,
+                        tokenizer.eos_token_id,
+                    )
+                )
+
     return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
     input_dict = dict(zip(input_keys, return_list))
     return input_dict
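
# Illustrative sketch, not part of the patch: the collate hunks above grow
# input_keys and each sample's tensor list in lockstep, so the final
# dict(zip(input_keys, return_list)) pairs every key with its tensor. A toy
# analogue with hypothetical keys and shapes:
import numpy as np

input_keys = ["input_ids", "mtp_layer_mask"]
per_sample = [  # one inner list per sample, tensors in key order
    [np.zeros((1, 4), dtype=np.int64), np.ones((2, 4), dtype=np.int32)],
    [np.zeros((1, 4), dtype=np.int64), np.ones((2, 4), dtype=np.int32)],
]
batched = [np.concatenate(tensors) for tensors in zip(*per_sample)]
input_dict = dict(zip(input_keys, batched))
assert input_dict["mtp_layer_mask"].shape == (4, 4)  # two samples, depth 2
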
@@ -851,3 +886,118 @@ def gen_attn_mask_startend_row_indices(
         attn_mask_startend_row_indices.extend(list(range(offset, max_seq_len)))
     # NOTE(hehuang): The dtype of attn_mask_startend_row_indices must be np.int32
     return np.array(attn_mask_startend_row_indices, dtype=np.int32)[None, None, ..., None]  # add dimension modify
+
+
+def gen_mtp_attn_mask(
+    batch_token_ids: List[List[int]],
+    max_seq_len: int,
+    mtp_depth: int,
+    use_global_causal_attn: bool,
+) -> np.ndarray:
+    """Generate MTP per-layer attention mask (2D matrix form).
+
+    Args:
+        batch_token_ids: List of token ID sequences (document grouping provides boundaries).
+        max_seq_len: Padded sequence length, already extended by mtp_depth.
+        mtp_depth: Number of MTP prediction layers D.
+        use_global_causal_attn: If True, use global causal mask (single block);
+            otherwise use block-causal mask with per-layer shifted boundaries.
+
+    Returns:
+        np.ndarray, shape [mtp_depth, 1, max_seq_len, max_seq_len], dtype=float32.
+    """
+    total_len = sum(len(ids) for ids in batch_token_ids)
+    if use_global_causal_attn:
+        single = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)
+        single[:total_len, :total_len] = np.tril(np.ones([total_len, total_len]))
+        result = np.stack([single] * mtp_depth, axis=0)
+    else:
+        internal_boundaries = []
+        offset = 0
+        for ids in batch_token_ids[:-1]:
+            offset += len(ids)
+            internal_boundaries.append(offset)
+        result = []
+        for mtp_idx in range(mtp_depth):
+            mask = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)
+            shift = mtp_idx + 1
+            all_boundaries = [b - shift for b in internal_boundaries if b - shift > 0] + [total_len]
+            prev = 0
+            for boundary in all_boundaries:
+                if boundary > prev:
+                    mask[prev:boundary, prev:boundary] = np.tril(np.ones([boundary - prev, boundary - prev]))
+                prev = boundary
+            result.append(mask)
+        result = np.stack(result, axis=0)
+    return result[:, None, :, :]
+
+
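
# Illustrative sketch, not part of the patch: behaviour of gen_mtp_attn_mask
# (defined above) on a toy packed batch of two documents, [1, 2, 3] and [4, 5],
# padded to max_seq_len = 5 + mtp_depth. For MTP layer k the document boundary
# at offset 3 shifts left by k + 1, so attention never crosses a boundary that
# the shifted targets would leak across.
masks = gen_mtp_attn_mask([[1, 2, 3], [4, 5]], max_seq_len=7, mtp_depth=2,
                          use_global_causal_attn=False)
assert masks.shape == (2, 1, 7, 7)
assert masks[0, 0, 1, 0] == 1.0  # layer 0: causal blocks [0:2) and [2:5)
assert masks[0, 0, 2, 1] == 0.0  # position 2 opens the shifted second block
assert masks[1, 0, 1, 0] == 0.0  # layer 1: boundary shifts further, to 1
assert masks[0, 0, 5, 5] == 0.0  # padding rows stay fully masked
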
+def gen_mtp_attn_mask_startend_row_indices(
+    batch_token_ids: List[List[int]],
+    max_seq_len: int,
+    mtp_depth: int,
+    use_global_causal_attn: bool,
+) -> np.ndarray:
+    """Generate MTP per-layer attention mask (compressed startend_row_indices form).
+
+    Args:
+        batch_token_ids: List of token ID sequences.
+        max_seq_len: Padded sequence length, already extended by mtp_depth.
+        mtp_depth: Number of MTP prediction layers D.
+        use_global_causal_attn: If True, single global block; otherwise per-layer shifted blocks.
+
+    Returns:
+        np.ndarray, shape [mtp_depth, 1, max_seq_len, 1], dtype=int32.
+    """
+    total_len = sum(len(ids) for ids in batch_token_ids)
+    pad_indices = list(range(total_len, max_seq_len))
+    if use_global_causal_attn:
+        row = [total_len] * total_len + pad_indices
+        result = np.array([row] * mtp_depth, dtype=np.int32)
+    else:
+        internal_boundaries = []
+        offset = 0
+        for ids in batch_token_ids[:-1]:
+            offset += len(ids)
+            internal_boundaries.append(offset)
+        result = []
+        for mtp_idx in range(mtp_depth):
+            shift = mtp_idx + 1
+            all_boundaries = [b - shift for b in internal_boundaries if b - shift > 0] + [total_len]
+            indices = []
+            prev = 0
+            for boundary in all_boundaries:
+                indices.extend([boundary] * (boundary - prev))
+                prev = boundary
+            result.append(indices + pad_indices)
+        result = np.array(result, dtype=np.int32)
+    return result[:, None, :, None]
+
+
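
# Illustrative sketch, not part of the patch: the compressed form (defined
# above) stores, for each position, the exclusive end row of its attention
# block instead of a dense matrix; past total_len each pad position stores
# its own index. Same toy batch as before:
rows = gen_mtp_attn_mask_startend_row_indices(
    [[1, 2, 3], [4, 5]], max_seq_len=7, mtp_depth=2, use_global_causal_attn=False
)
assert rows.shape == (2, 1, 7, 1) and rows.dtype == np.int32
assert rows[0, 0, :, 0].tolist() == [2, 2, 5, 5, 5, 5, 6]  # layer 0: blocks end at 2, 5
assert rows[1, 0, :, 0].tolist() == [1, 5, 5, 5, 5, 5, 6]  # layer 1: first block ends at 1
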
+def gen_mtp_layer_mask(
+    batch_token_ids: List[List[int]],
+    max_seq_len: int,
+    mtp_depth: int,
+    eos_token_id: int = None,
+) -> np.ndarray:
+    """Generate MTP per-layer hidden inputs mask.
+
+    Args:
+        batch_token_ids: List of token ID sequences.
+        max_seq_len: Padded sequence length, already extended by mtp_depth.
+        mtp_depth: Number of MTP prediction layers D.
+        eos_token_id: If provided, zero out positions where EOS appears in the shifted input.
+
+    Returns:
+        np.ndarray, shape [mtp_depth, max_seq_len], dtype=int32.
+    """
+    if eos_token_id is None:
+        return np.ones((mtp_depth, max_seq_len), dtype=np.int32)
+    all_token_ids = np.concatenate([np.array(ids, dtype=np.int32) for ids in batch_token_ids])
+    result = []
+    for mtp_idx in range(mtp_depth):
+        mask = np.ones(max_seq_len, dtype=np.int32)
+        shifted = all_token_ids[mtp_idx + 1:]
+        mask[np.where(shifted == eos_token_id)[0]] = 0
+        result.append(mask)
+    return np.stack(result, axis=0)
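
# Illustrative sketch, not part of the patch: gen_mtp_layer_mask (defined
# above) zeroes position i at MTP depth k whenever the token shifted in at
# that position (token i + k + 1) is EOS, so the extra prediction layers do
# not consume inputs that leak across a document boundary. Toy batch with a
# hypothetical eos_token_id = 9:
layer_mask = gen_mtp_layer_mask([[1, 2, 9], [4, 9]], max_seq_len=7, mtp_depth=2,
                                eos_token_id=9)
assert layer_mask.shape == (2, 7)
assert layer_mask[0].tolist() == [1, 0, 1, 0, 1, 1, 1]  # depth 1: EOS shifts to 1, 3
assert layer_mask[1].tolist() == [0, 1, 0, 1, 1, 1, 1]  # depth 2: EOS shifts to 0, 2
assert gen_mtp_layer_mask([[1, 2]], 4, 2).min() == 1  # no eos_token_id -> all ones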