Skip to content

Commit 1a44692

Browse files
Merge pull request #3067 from AI-Hypercomputer:xfgu-rl-sharding
PiperOrigin-RevId: 864988890
2 parents b43d692 + b62a65f commit 1a44692

1 file changed

Lines changed: 18 additions & 1 deletion

File tree

src/MaxText/rl/train_rl.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,15 +191,32 @@ def setup_configs_and_devices(argv: list[str]):
191191
for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
192192
sampler_devices.extend(devices_by_slice[slice_indices[i]])
193193

194+
trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
195+
trainer_fsdp = trainer_devices_per_slice
196+
tp = config.ici_tensor_parallelism
197+
if tp > 1:
198+
if trainer_devices_per_slice % tp != 0:
199+
raise ValueError(
200+
f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
201+
)
202+
if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
203+
raise ValueError(
204+
f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
205+
f"devices_per_slice ({trainer_devices_per_slice})"
206+
)
207+
trainer_fsdp = trainer_devices_per_slice // tp
208+
194209
trainer_update = {
195210
"num_slices": config.num_trainer_slices,
196-
"ici_fsdp_parallelism": len(trainer_devices) // config.num_trainer_slices,
211+
"ici_fsdp_parallelism": trainer_fsdp,
212+
"ici_tensor_parallelism": tp,
197213
"dcn_data_parallelism": config.num_trainer_slices,
198214
}
199215

200216
sampler_update = {
201217
"num_slices": config.num_samplers_slices,
202218
"ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
219+
"ici_tensor_parallelism": -1,
203220
"dcn_data_parallelism": config.num_samplers_slices,
204221
}
205222

0 commit comments

Comments
 (0)