Skip to content

Commit b5f41ec

Browse files
Merge pull request #3284 from AI-Hypercomputer:jimmytsai/fix-sft-masking
PiperOrigin-RevId: 877960517
2 parents 2e56b5d + 2a52097 commit b5f41ec

2 files changed

Lines changed: 46 additions & 22 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def is_conversational(features, data_columns):
173173

174174
def apply_chat_template(example, tokenizer_model, data_column_name):
175175
"""Formats conversational data by applying the tokenizer's chat template
176-
and identifying prompt/completion segments.
176+
and identifying prompt/completion segments for SFT masking.
177177
178178
Args:
179179
example: A dictionary containing conversational data. It is expected to have a key
@@ -187,9 +187,10 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
187187
The modified `example` dictionary.
188188
- The `data_column_name` column will be updated to a list of
189189
messages, each formatted according to the tokenizer's chat template.
190-
- A new column named "is_prompt" will be added, where `True`
191-
indicates a system message or a user message (prompt) and `False` indicates an assistant
192-
message (completion).
190+
- A new column "is_prompt" is added, where `True` indicates the
191+
tokens contain the system message, user message, and generation
192+
prompt (if applicable). `False` indicates the expected LLM
193+
completion, excluding the assistant's start tokens.
193194
"""
194195
messages = []
195196
is_prompt = []
@@ -203,7 +204,7 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
203204
elif message["role"] == "user":
204205
round_msgs.append(message)
205206
prompt_in_chat_template = tokenizer_model.apply_chat_template(
206-
round_msgs, add_generation_prompt=False, tokenize=False
207+
round_msgs, add_generation_prompt=True, tokenize=False
207208
)
208209
messages.append(prompt_in_chat_template)
209210
is_prompt.append(True)
@@ -212,7 +213,8 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
212213
prompt_completion_tokens = tokenizer_model.apply_chat_template(
213214
round_msgs, add_generation_prompt=False, tokenize=True
214215
)
215-
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=False, tokenize=True)
216+
# include generation_prompt as part of the prompt tokens
217+
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
216218
completion_tokens = prompt_completion_tokens[len(prompt_tokens) :]
217219
completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
218220
messages.append(completion_in_chat_template)

tests/unit/sft_data_processing_test.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,20 @@
208208
"truncated_exp1_targets": (
209209
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
210210
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
211-
"<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
211+
+ "<|endoftext|>" * 3
212+
+ "<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
212213
+ "<|endoftext|>" * 9
213-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer two<|endoftext|>"
214+
+ "<|endoftext|>" * 3
215+
+ "<think>\n\n</think>\n\nexample one answer two<|endoftext|>"
214216
),
215217
"truncated_exp1_targets_predictable": (
216218
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
217219
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
218-
"<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
220+
+ "<|endoftext|>" * 3
221+
+ "<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
219222
+ "<|endoftext|>" * 9
220-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer two<|endoftext|>"
223+
+ "<|endoftext|>" * 3
224+
+ "<think>\n\n</think>\n\nexample one answer two<|endoftext|>"
221225
),
222226
"packed_exp2_inputs": (
223227
"<|im_start|>user\nquestion two<|im_end|>\n"
@@ -227,15 +231,22 @@
227231
),
228232
"packed_exp2_targets": (
229233
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
230-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer two<|im_end|>\n"
234+
+ "<|endoftext|>" * 3
235+
+ "<think>\n\n</think>\n\nanswer two<|im_end|>\n"
231236
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
232-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer three<|im_end|>\n" + "!" * 14 + "<|endoftext|>"
237+
+ "<|endoftext|>" * 3
238+
+ "<think>\n\n</think>\n\nanswer three<|im_end|>\n"
239+
+ "!" * 14
240+
+ "<|endoftext|>"
233241
),
234242
"packed_exp2_targets_predictable": (
235243
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
236-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer two<|im_end|>\n"
244+
+ "<|endoftext|>" * 3
245+
+ "<think>\n\n</think>\n\nanswer two<|im_end|>\n"
237246
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
238-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer three<|im_end|>\n" + "<|endoftext|>" * 15
247+
+ "<|endoftext|>" * 3
248+
+ "<think>\n\n</think>\n\nanswer three<|im_end|>\n"
249+
+ "<|endoftext|>" * 15
239250
),
240251
},
241252
"prompt_completion": {
@@ -248,16 +259,20 @@
248259
),
249260
"truncated_exp1_targets": (
250261
"<|endoftext|>" * 8
251-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
262+
+ "<|endoftext|>" * 3
263+
+ "<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
252264
+ "<|endoftext|>" * 9
253-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer two<|im_end|>\n"
265+
+ "<|endoftext|>" * 3
266+
+ "<think>\n\n</think>\n\nexample one answer two<|im_end|>\n"
254267
+ "<|endoftext|>" * 7
255268
),
256269
"truncated_exp1_targets_predictable": (
257270
"<|endoftext|>" * 8
258-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
271+
+ "<|endoftext|>" * 3
272+
+ "<think>\n\n</think>\n\nexample one answer one<|im_end|>\n"
259273
+ "<|endoftext|>" * 9
260-
+ "<|im_start|>assistant\n<think>\n\n</think>\n\nexample one answer two<|im_end|>\n"
274+
+ "<|endoftext|>" * 3
275+
+ "<think>\n\n</think>\n\nexample one answer two<|im_end|>\n"
261276
+ "<|endoftext|>" * 7
262277
),
263278
"packed_exp2_inputs": (
@@ -268,15 +283,22 @@
268283
),
269284
"packed_exp2_targets": (
270285
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
271-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer two<|im_end|>\n"
286+
+ "<|endoftext|>" * 3
287+
+ "<think>\n\n</think>\n\nanswer two<|im_end|>\n"
272288
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
273-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer three<|im_end|>\n" + "!" * 14 + "<|endoftext|>"
289+
+ "<|endoftext|>" * 3
290+
+ "<think>\n\n</think>\n\nanswer three<|im_end|>\n"
291+
+ "!" * 14
292+
+ "<|endoftext|>"
274293
),
275294
"packed_exp2_targets_predictable": (
276295
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
277-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer two<|im_end|>\n"
296+
+ "<|endoftext|>" * 3
297+
+ "<think>\n\n</think>\n\nanswer two<|im_end|>\n"
278298
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
279-
"<|im_start|>assistant\n<think>\n\n</think>\n\nanswer three<|im_end|>\n" + "<|endoftext|>" * 15
299+
+ "<|endoftext|>" * 3
300+
+ "<think>\n\n</think>\n\nanswer three<|im_end|>\n"
301+
+ "<|endoftext|>" * 15
280302
),
281303
},
282304
}

0 commit comments

Comments (0)