@@ -54,7 +54,13 @@ def vision_sft_preprocessing_pipeline(
5454 """pipeline for multimodal SFT with HF dataset"""
5555
5656 assert len (text_columns ) == 2 , f"Need two text_columns for query and response, received { text_columns = } "
57- batch_size = global_batch_size // jax .process_count ()
57+ # Tunix GA requires per-micro-batch slicing at the data level,
58+ # whereas Native GA processes the full batch and splits it internally.
59+ if config .use_tunix_gradient_accumulation :
60+ batch_size = global_batch_size // jax .process_count () // config .gradient_accumulation_steps
61+ else :
62+ batch_size = global_batch_size // jax .process_count ()
63+
5864 if config .enable_data_shuffling :
5965 dataset = dataset .shuffle (seed = config .data_shuffle_seed )
6066
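For intuition, here is a minimal sketch of the batch-size arithmetic this hunk introduces; the concrete numbers are assumptions for illustration, not values from the change:

```python
# Illustrative arithmetic only; the values below are made up.
global_batch_size, process_count, ga_steps = 32, 4, 2

# Native GA: the pipeline emits full per-host batches and the trainer
# splits each one into `ga_steps` micro-batches internally.
native_host_batch = global_batch_size // process_count             # 8

# Tunix GA: slicing happens at the data level, so the pipeline itself
# emits micro-batches and the trainer consumes `ga_steps` of them per
# optimizer update.
tunix_host_batch = global_batch_size // process_count // ga_steps  # 4

assert native_host_batch == tunix_host_batch * ga_steps
```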
@@ -195,13 +201,21 @@ def preprocessing_pipeline(
     generate_padding_batch=False,
     use_dpo=None,
     use_sft=None,
+    use_tunix_gradient_accumulation=False,
+    num_microbatches=1,
     sft_train_on_completion_only=True,
     grain_worker_count=1,  # only support 0 or 1
     max_segments_per_seq=None,
 ):
   """pipeline for preprocessing HF dataset"""

   assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible by number of global devices."
+  # Tunix GA requires per-micro-batch slicing at the data level,
+  # whereas Native GA processes the full batch and splits it internally.
+  if use_tunix_gradient_accumulation:
+    batch_size = global_batch_size // jax.process_count() // num_microbatches
+  else:
+    batch_size = global_batch_size // jax.process_count()

   if shuffle:
     dataset = dataset.shuffle(seed=data_shuffle_seed)
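One caveat worth noting: the integer divisions above floor silently, so a `global_batch_size` that is not divisible by `process_count * num_microbatches` would shrink batches without warning. A hypothetical validation (not part of this change) could catch that early:

```python
# Hypothetical sanity check, not part of this diff.
def validate_ga_config(global_batch_size, process_count, num_microbatches):
  assert global_batch_size % (process_count * num_microbatches) == 0, (
      f"global_batch_size={global_batch_size} must be divisible by "
      f"process_count * num_microbatches = {process_count * num_microbatches}"
  )

validate_ga_config(32, 4, 2)  # passes
# validate_ga_config(32, 4, 3)  # would raise AssertionError
```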
@@ -303,15 +317,15 @@ def lists2array(x):
       max_segments = None
     operations.append(
         grain.experimental.PackAndBatchOperation(
-            batch_size=global_batch_size // jax.process_count(),
+            batch_size=batch_size,
             length_struct=length_struct,
             max_sequences_per_bin=max_segments,
         )
     )
     operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
   else:
     operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
-    operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))
+    operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

   if shift and not use_dpo:
     operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))
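Since both the packing and padding branches now batch at the precomputed `batch_size`, a Tunix-style trainer would pull `ga_steps` iterator steps per optimizer update. The following toy loop is a sketch of that consumption pattern in plain JAX; `loss_fn`, the synthetic data, and the update rule are all placeholders for exposition, not Tunix's actual API:

```python
import jax
import jax.numpy as jnp

ga_steps = 2
params = jnp.ones((4,))

def loss_fn(params, batch):
  # Toy quadratic loss; stands in for the real model's loss.
  return jnp.mean((batch @ params) ** 2)

grad_fn = jax.grad(loss_fn)

def micro_batches():
  # Stand-in for the grain iterator: yields one micro-batch per call.
  key = jax.random.PRNGKey(0)
  while True:
    key, sub = jax.random.split(key)
    yield jax.random.normal(sub, (4, 4))

it = micro_batches()

# One optimizer step accumulates gradients over `ga_steps` micro-batches,
# mirroring how data-level slicing feeds gradient accumulation.
accum = jnp.zeros_like(params)
for _ in range(ga_steps):
  accum = accum + grad_fn(params, next(it))
params = params - 0.1 * (accum / ga_steps)
```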
@@ -390,6 +404,8 @@ def make_hf_train_iterator(
       generate_padding_batch=config.generate_padding_batch_train,
       use_dpo=config.use_dpo,
       use_sft=config.use_sft,
+      use_tunix_gradient_accumulation=config.use_tunix_gradient_accumulation,
+      num_microbatches=config.gradient_accumulation_steps,
       sft_train_on_completion_only=config.sft_train_on_completion_only,
       chat_template_path=config.chat_template_path,
       max_segments_per_seq=config.max_segments_per_seq,
@@ -443,6 +459,7 @@ def make_hf_eval_iterator(
       generate_padding_batch=config.generate_padding_batch_eval,
       use_dpo=config.use_dpo,
       use_sft=config.use_sft,
+      num_microbatches=config.gradient_accumulation_steps,
       sft_train_on_completion_only=config.sft_train_on_completion_only,
       chat_template_path=config.chat_template_path,
       max_segments_per_seq=config.max_segments_per_seq,