We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e04e78d commit 5b91824Copy full SHA for 5b91824
1 file changed
src/maxdiffusion/pyconfig.py
@@ -255,7 +255,7 @@ def user_init(raw_keys):
255
raw_keys["global_batch_size_to_train_on"],
256
) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
257
258
- if getattr(raw_keys, "vae_spatial", -1) == -1 or "vae_spatial" in raw_keys and raw_keys["vae_spatial"] == -1:
+ if raw_keys.get("vae_spatial", -1) == -1:
259
total_device = len(jax.devices())
260
dp = raw_keys.get("ici_data_parallelism", 1) * raw_keys.get("dcn_data_parallelism", 1)
261
if dp == -1 or dp == 0:
0 commit comments