# NOTE: `import torch` is required to run this snippet; CollateFnIF and
# DatasetBatch are defined elsewhere in the surrounding codebase (not shown in this diff).
import torch


class GPT2LLMCollateFn(CollateFnIF):
    """GPT2LLMCollateFn class to define a collate function for the GPT2 language model."""

    def __init__(
        self,
        sample_key: str,
        target_key: str,
        sub_seq_lengths_key: str | None = None,
        eos_token_id: int | None = None,
        padding_token_id: int | None = None,
    ):
1118 """
1219 Initializes the Collator object.
20+ If the eos token ID and the sub_seq_lengths_key are provided,
21+ a list[list[int]] representing the sub-sequence lengths will be created.
1322
1423 Args:
1524 sample_key (str): The key for accessing the sample data.
1625 target_key (str): The key for accessing the target data.
26+ sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
27+ eos_token_id (int | None): The end-of-sequence token ID.
28+ padding_token_id (int | None): The padding token ID.
1729 """
        self.sample_key = sample_key
        self.target_key = target_key
        self.sub_seq_lengths_key = sub_seq_lengths_key
        self.eos_token_id = eos_token_id
        self.padding_token_id = padding_token_id

    def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
        """
        ...
        """
        sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
        samples = {self.sample_key: sample_tensor[:, :-1]}
        targets = {self.target_key: sample_tensor[:, 1:]}
        if self.sub_seq_lengths_key is not None:
            # Determine the sub-sequence lengths by finding the eos tokens in each sequence of the batch.
            sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key])
            samples[self.sub_seq_lengths_key] = sub_seq_lengths
        return DatasetBatch(targets=targets, samples=samples)
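    # Example of the shift (hypothetical tokens): a batch row [t0, t1, t2, t3]
    # yields the sample [t0, t1, t2] and the target [t1, t2, t3]; the target at
    # each index is the next-token label for the sample at that index.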

    def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
        sub_seq_lengths_in_batch = []
        for batch_seq in sample_tensor:
            eos_positions = (batch_seq == self.eos_token_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) == 0:
                # No eos token found: treat the whole row as one cut-off sub-sequence;
                # if a padding token is configured, the row must not contain any padding.
                assert self.padding_token_id is None or (
                    batch_seq[0] != self.padding_token_id and torch.all(batch_seq != self.padding_token_id)
                ), "Whole batch sequence consists of padding tokens."
                sub_seq_lengths_in_batch.append([len(batch_seq)])
            else:
                lens_in_seq = self._compute_subsequence_length_in_sequence(batch_seq, eos_positions)
                sub_seq_lengths_in_batch.append(lens_in_seq)
        return sub_seq_lengths_in_batch
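    # Example for the method above (hypothetical values, eos_token_id=0):
    # sample_tensor = [[1, 2, 0, 3], [4, 5, 6, 7]] yields [[3, 1], [4]];
    # the second row has no eos token, so it counts as one cut-off
    # sub-sequence covering the full row.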

    def _compute_subsequence_length_in_sequence(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]:
        # If the last sub-sequence is cut off, i.e. does not end on an eos token,
        # it should also be included, unless the padding token is set and
        # the last sub-sequence is just padding.
        last_eos_pos = eos_positions[-1].item()
        if self._has_cutoff_final_sequence(seq, last_eos_pos):
            eos_positions = torch.cat([eos_positions, eos_positions.new_tensor([len(seq) - 1])])
        # Compute the length of each sub-sequence and add it to the lengths list.
        sub_seq_lengths = []
        prev_pos = 0
        for pos in eos_positions:
            sub_seq_lengths.append(pos.item() - prev_pos + 1)
            prev_pos = pos.item() + 1
        return sub_seq_lengths
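    # Worked example (hypothetical tokens): for seq = [a, b, eos, c, d, eos, e],
    # the eos positions are 2 and 5; the cut-off tail adds position 6, giving
    # sub-sequence lengths [3, 3, 1].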

    def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool:
        # Assumption: if the first token of the last sub-sequence is padding, so is the rest.
        return last_eos_pos < len(seq) - 1 and (
            self.padding_token_id is None or seq[last_eos_pos + 1] != self.padding_token_id
        )
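
# A minimal usage sketch, not part of the diff above. The key names and token
# IDs are hypothetical, the eos token ID is assumed to be 0, and plain lists
# are passed as samples (torch.tensor accepts them in __call__).
collate_fn = GPT2LLMCollateFn(
    sample_key="input_ids",
    target_key="target_ids",
    sub_seq_lengths_key="sub_seq_lengths",
    eos_token_id=0,
)
batch = [{"input_ids": [5, 6, 0, 7, 8, 0, 9, 0]}]
dataset_batch = collate_fn(batch)
# samples["input_ids"]       -> tensor([[5, 6, 0, 7, 8, 0, 9]])  (last token dropped)
# targets["target_ids"]      -> tensor([[6, 0, 7, 8, 0, 9, 0]])  (shifted by one)
# samples["sub_seq_lengths"] -> [[3, 3, 1]]  (eos at positions 2 and 5, plus the cut-off tail)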