Skip to content

Commit 9cf4088

Browse files
Merge pull request #2970 from AI-Hypercomputer:support-system-role-in-hf-input
PiperOrigin-RevId: 862960371
2 parents 8c4d0b0 + 095c53a commit 9cf4088

3 files changed

Lines changed: 297 additions & 128 deletions

File tree

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@
3131
from MaxText import multihost_dataloading
3232

3333

34+
def _get_pad_id(tokenizer):
35+
if tokenizer.pad_token_id is not None:
36+
pad_id = tokenizer.pad_token_id
37+
elif tokenizer.unk_token_id is not None:
38+
pad_id = tokenizer.unk_token_id
39+
else:
40+
pad_id = -1
41+
return pad_id
42+
43+
3444
def vision_sft_preprocessing_pipeline(
3545
dataset,
3646
config,
@@ -89,12 +99,7 @@ def vision_sft_preprocessing_pipeline(
8999
legacy=False,
90100
token=config.hf_access_token,
91101
)
92-
if tokenizer.pad_token_id is not None:
93-
pad_id = tokenizer.pad_token_id
94-
elif tokenizer.unk_token_id is not None:
95-
pad_id = tokenizer.unk_token_id
96-
else:
97-
pad_id = -1
102+
pad_id = _get_pad_id(tokenizer)
98103

99104
dataset = dataset.map(
100105
_input_pipeline_utils.tokenization,
@@ -246,12 +251,7 @@ def preprocessing_pipeline(
246251
else:
247252
dataset = dataset.select_columns(data_column_names)
248253

249-
if tokenizer.pad_token_id is not None:
250-
pad_id = tokenizer.pad_token_id
251-
elif tokenizer.unk_token_id is not None:
252-
pad_id = tokenizer.unk_token_id
253-
else:
254-
pad_id = -1
254+
pad_id = _get_pad_id(tokenizer)
255255

256256
if tokenize:
257257
dataset = dataset.map(

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,30 +170,37 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
170170
- The `data_column_name` column will be updated to a list of
171171
messages, each formatted according to the tokenizer's chat template.
172172
- A new column named "is_prompt" will be added, where `True`
173-
indicates a user message (prompt) and `False` indicates an assistant
173+
indicates a system message or a user message (prompt) and `False` indicates an assistant
174174
message (completion).
175175
"""
176176
messages = []
177177
is_prompt = []
178-
prompt = None
178+
round_msgs = []
179179
try:
180-
for message in example[data_column_name]:
181-
if message["role"] == "user":
182-
prompt = message
180+
for idx, message in enumerate(example[data_column_name]):
181+
if message["role"] == "system":
182+
if idx != 0:
183+
raise ValueError(f"System message found at index {idx}. System messages must be at index 0.")
184+
round_msgs.append(message)
185+
elif message["role"] == "user":
186+
round_msgs.append(message)
183187
prompt_in_chat_template = tokenizer_model.apply_chat_template(
184-
[prompt], add_generation_prompt=False, tokenize=False
188+
round_msgs, add_generation_prompt=False, tokenize=False
185189
)
186190
messages.append(prompt_in_chat_template)
187191
is_prompt.append(True)
188192
elif message["role"] == "assistant":
193+
round_msgs.append(message)
189194
prompt_completion_tokens = tokenizer_model.apply_chat_template(
190-
[prompt, message], add_generation_prompt=False, tokenize=True
195+
round_msgs, add_generation_prompt=False, tokenize=True
191196
)
192-
prompt_tokens = tokenizer_model.apply_chat_template([prompt], add_generation_prompt=False, tokenize=True)
197+
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=False, tokenize=True)
193198
completion_tokens = prompt_completion_tokens[len(prompt_tokens) :]
194199
completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
195200
messages.append(completion_in_chat_template)
196201
is_prompt.append(False)
202+
# Round ended, clearing the buffer.
203+
round_msgs.clear()
197204
except ValueError as e:
198205
max_logging.log(f"Unable to apply chat template: {e}")
199206
raise e
@@ -688,6 +695,7 @@ def shift_left(x, pad_id, axis=1):
688695
def shift_and_refine(x, ignored_ids, axis=1):
689696
"""Shift inputs, set segmentation to 0 when target element is in ignored_ids if provided"""
690697
x["targets"] = shift_left(x["targets"], ignored_ids[0], axis=axis)
698+
x["targets_segmentation"] = shift_left(x["targets_segmentation"], 0, axis=axis)
691699
for ignore_id in ignored_ids:
692700
x["targets_segmentation"] = np.where(x["targets"] != ignore_id, x["targets_segmentation"], 0)
693701

0 commit comments

Comments
 (0)