@@ -170,30 +170,37 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
170170 - The `data_column_name` column will be updated to a list of
171171 messages, each formatted according to the tokenizer's chat template.
172172 - A new column named "is_prompt" will be added, where `True`
173- indicates a user message (prompt) and `False` indicates an assistant
173+ indicates a system message or a user message (prompt) and `False` indicates an assistant
174174 message (completion).
175175 """
176176 messages = []
177177 is_prompt = []
178- prompt = None
178+ round_msgs = []
179179 try :
180- for message in example [data_column_name ]:
181- if message ["role" ] == "user" :
182- prompt = message
180+ for idx , message in enumerate (example [data_column_name ]):
181+ if message ["role" ] == "system" :
182+ if idx != 0 :
183+ raise ValueError (f"System message found at index { idx } . System messages must be at index 0." )
184+ round_msgs .append (message )
185+ elif message ["role" ] == "user" :
186+ round_msgs .append (message )
183187 prompt_in_chat_template = tokenizer_model .apply_chat_template (
184- [ prompt ] , add_generation_prompt = False , tokenize = False
188+ round_msgs , add_generation_prompt = False , tokenize = False
185189 )
186190 messages .append (prompt_in_chat_template )
187191 is_prompt .append (True )
188192 elif message ["role" ] == "assistant" :
193+ round_msgs .append (message )
189194 prompt_completion_tokens = tokenizer_model .apply_chat_template (
190- [ prompt , message ] , add_generation_prompt = False , tokenize = True
195+ round_msgs , add_generation_prompt = False , tokenize = True
191196 )
192- prompt_tokens = tokenizer_model .apply_chat_template ([ prompt ], add_generation_prompt = False , tokenize = True )
197+ prompt_tokens = tokenizer_model .apply_chat_template (round_msgs [: - 1 ], add_generation_prompt = False , tokenize = True )
193198 completion_tokens = prompt_completion_tokens [len (prompt_tokens ) :]
194199 completion_in_chat_template = tokenizer_model .decode (completion_tokens , skip_special_tokens = False )
195200 messages .append (completion_in_chat_template )
196201 is_prompt .append (False )
202+ # Round ended, clearing the buffer.
203+ round_msgs .clear ()
197204 except ValueError as e :
198205 max_logging .log (f"Unable to apply chat template: { e } " )
199206 raise e
@@ -688,6 +695,7 @@ def shift_left(x, pad_id, axis=1):
688695def shift_and_refine (x , ignored_ids , axis = 1 ):
689696 """Shift inputs, set segmentation to 0 when target element is in ignored_ids if provided"""
690697 x ["targets" ] = shift_left (x ["targets" ], ignored_ids [0 ], axis = axis )
698+ x ["targets_segmentation" ] = shift_left (x ["targets_segmentation" ], 0 , axis = axis )
691699 for ignore_id in ignored_ids :
692700 x ["targets_segmentation" ] = np .where (x ["targets" ] != ignore_id , x ["targets_segmentation" ], 0 )
693701
0 commit comments