@@ -191,15 +191,32 @@ def setup_configs_and_devices(argv: list[str]):
191191 for i in range (config .num_trainer_slices , config .num_trainer_slices + config .num_samplers_slices ):
192192 sampler_devices .extend (devices_by_slice [slice_indices [i ]])
193193
194+ trainer_devices_per_slice = len (trainer_devices ) // config .num_trainer_slices
195+ trainer_fsdp = trainer_devices_per_slice
196+ tp = config .ici_tensor_parallelism
197+ if tp > 1 :
198+ if trainer_devices_per_slice % tp != 0 :
199+ raise ValueError (
200+ f"trainer_devices_per_slice ({ trainer_devices_per_slice } ) must be divisible by tensor parallelism ({ tp } )"
201+ )
202+ if config .ici_fsdp_parallelism != - 1 and config .ici_fsdp_parallelism * tp != trainer_devices_per_slice :
203+ raise ValueError (
204+ f"ici_fsdp_parallelism ({ config .ici_fsdp_parallelism } ) * ici_tensor_parallelism ({ tp } ) must equal "
205+ f"devices_per_slice ({ trainer_devices_per_slice } )"
206+ )
207+ trainer_fsdp = trainer_devices_per_slice // tp
208+
194209 trainer_update = {
195210 "num_slices" : config .num_trainer_slices ,
196- "ici_fsdp_parallelism" : len (trainer_devices ) // config .num_trainer_slices ,
211+ "ici_fsdp_parallelism" : trainer_fsdp ,
212+ "ici_tensor_parallelism" : tp ,
197213 "dcn_data_parallelism" : config .num_trainer_slices ,
198214 }
199215
200216 sampler_update = {
201217 "num_slices" : config .num_samplers_slices ,
202218 "ici_fsdp_parallelism" : len (sampler_devices ) // config .num_samplers_slices ,
219+ "ici_tensor_parallelism" : - 1 ,
203220 "dcn_data_parallelism" : config .num_samplers_slices ,
204221 }
205222
0 commit comments