
Commit d9e35ee

chore: Merge branch 'inter_document_masking_for_attention' into loss_filtering
2 parents: 864eae2 + 956b958

6 files changed: 1045 additions & 47 deletions


src/modalities/config/config.py

Lines changed: 15 additions & 0 deletions
@@ -461,6 +461,21 @@ class BatchSamplerConfig(BaseModel):
 class GPT2LLMCollateFnConfig(BaseModel):
     sample_key: str
     target_key: str
+    sub_seq_lengths_key: str | None = None
+    eos_token_id: int | None = None
+    padding_token_id: int | None = None
+
+    @model_validator(mode="after")
+    def check_sub_seq_lengths_and_eos_token(self) -> "GPT2LLMCollateFnConfig":
+        if (self.sub_seq_lengths_key is None) != (self.eos_token_id is None):
+            raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
+        return self
+
+    @model_validator(mode="after")
+    def check_padding_token_and_sub_seq_lengths(self) -> "GPT2LLMCollateFnConfig":
+        if self.padding_token_id is not None and self.sub_seq_lengths_key is None:
+            raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
+        return self
 
 
 class LLMDataLoaderConfig(BaseModel):
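
As a sketch of how the two new validators interact (not part of the commit; the import path and concrete token IDs are assumptions for illustration), the config could be exercised like this:

# Sketch only: behavior of the new GPT2LLMCollateFnConfig fields.
from modalities.config.config import GPT2LLMCollateFnConfig

# Valid: sub_seq_lengths_key and eos_token_id are provided together,
# and padding_token_id only appears alongside sub_seq_lengths_key.
cfg = GPT2LLMCollateFnConfig(
    sample_key="input_ids",
    target_key="target_ids",
    sub_seq_lengths_key="sub_seq_lengths",
    eos_token_id=50256,     # illustrative value
    padding_token_id=50257, # illustrative value
)

# Invalid: eos_token_id without sub_seq_lengths_key
# -> check_sub_seq_lengths_and_eos_token raises ValueError (surfaced by pydantic as a ValidationError).
# GPT2LLMCollateFnConfig(sample_key="input_ids", target_key="target_ids", eos_token_id=50256)

# Invalid: padding_token_id without sub_seq_lengths_key
# -> check_padding_token_and_sub_seq_lengths raises ValueError.
# GPT2LLMCollateFnConfig(sample_key="input_ids", target_key="target_ids", padding_token_id=50257)

Coupling eos_token_id to sub_seq_lengths_key ensures the collator never emits sub-sequence boundaries without knowing which token marks them.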

src/modalities/models/gpt2/collator.py

Lines changed: 55 additions & 1 deletion
@@ -7,16 +7,31 @@
 class GPT2LLMCollateFn(CollateFnIF):
     """GPT2LLMCollateFn class to define a collate function for GPT2 language model."""
 
-    def __init__(self, sample_key: str, target_key: str):
+    def __init__(
+        self,
+        sample_key: str,
+        target_key: str,
+        sub_seq_lengths_key: str | None = None,
+        eos_token_id: int | None = None,
+        padding_token_id: int | None = None,
+    ):
         """
         Initializes the Collator object.
+        If the eos token ID and the sub_seq_lengths_key are provided,
+        a list[list[int]] representing the sub-sequence lengths will be created.
 
         Args:
             sample_key (str): The key for accessing the sample data.
             target_key (str): The key for accessing the target data.
+            sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
+            eos_token_id (int | None): The end-of-sequence token ID.
+            padding_token_id (int | None): The padding token ID.
         """
         self.sample_key = sample_key
         self.target_key = target_key
+        self.sub_seq_lengths_key = sub_seq_lengths_key
+        self.eos_token_id = eos_token_id
+        self.padding_token_id = padding_token_id
 
     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
         """
@@ -33,4 +48,43 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
         sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
         samples = {self.sample_key: sample_tensor[:, :-1]}
         targets = {self.target_key: sample_tensor[:, 1:]}
+        if self.sub_seq_lengths_key is not None:
+            # Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
+            sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key])
+            samples[self.sub_seq_lengths_key] = sub_seq_lengths
         return DatasetBatch(targets=targets, samples=samples)
+
+    def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
+        sub_seq_lengths_in_batch = []
+        for batch_seq in sample_tensor:
+            eos_positions = (batch_seq == self.eos_token_id).nonzero(as_tuple=True)[0]
+            if len(eos_positions) == 0:
+                assert self.padding_token_id is None or (
+                    batch_seq[0] != self.padding_token_id and torch.all(batch_seq != self.padding_token_id)
+                ), "Whole batch sequence consists of padding tokens."
+                sub_seq_lengths_in_batch.append([len(batch_seq)])
+            else:
+                lens_in_seq = self._compute_subsequence_length_in_sequence(batch_seq, eos_positions)
+                sub_seq_lengths_in_batch.append(lens_in_seq)
+        return sub_seq_lengths_in_batch
+
+    def _compute_subsequence_length_in_sequence(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]:
+        # If the last sequence is cut, i.e. does not end on an eos token,
+        # it should also be included unless the padding token is set and
+        # the last sequence is just padding.
+        last_eos_pos = eos_positions[-1].item()
+        if self._has_cutoff_final_sequence(seq, last_eos_pos):
+            eos_positions = torch.cat([eos_positions, eos_positions.new_tensor([len(seq) - 1])])
+        # Compute length of each subsequence and add to lengths list.
+        sub_seq_lengths = []
+        prev_pos = 0
+        for pos in eos_positions:
+            sub_seq_lengths.append(pos.item() - prev_pos + 1)
+            prev_pos = pos.item() + 1
+        return sub_seq_lengths
+
+    def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool:
+        # Assumption: If the first token of the last sequence is padding, so is the rest.
+        return last_eos_pos < len(seq) - 1 and (
+            self.padding_token_id is None or seq[last_eos_pos + 1] != self.padding_token_id
+        )
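
To make the collator change concrete, here is a small usage sketch (not part of the commit; key names, the eos token ID, and sample values are illustrative). When sub_seq_lengths_key is set, the collator records per-row document lengths, presumably for the inter-document attention masking referenced by the merged branch name:

# Sketch only: what the extended collate function produces for a tiny packed batch.
from modalities.models.gpt2.collator import GPT2LLMCollateFn

collate_fn = GPT2LLMCollateFn(
    sample_key="input_ids",
    target_key="target_ids",
    sub_seq_lengths_key="sub_seq_lengths",
    eos_token_id=2,  # illustrative eos token
)

# Two packed sequences of length 8; token 2 marks document boundaries.
batch = [
    {"input_ids": [5, 6, 2, 7, 8, 9, 2, 4]},
    {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8]},
]
db = collate_fn(batch)

# Samples drop the last token and targets drop the first (next-token shift), so both have length 7.
# Sub-sequence lengths are computed on the shifted samples tensor:
#   row 0 -> [3, 4]   (two complete documents, each ending in eos)
#   row 1 -> [2, 5]   (one eos-terminated document, plus a cut-off final document counted to the end)
# Assuming DatasetBatch keeps the dicts accessible as attributes:
#   db.samples["sub_seq_lengths"] == [[3, 4], [2, 5]]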
