Skip to content

Commit 5b91824

Browse files
committed
Merged CFG cache, 220 sec using tokamax_flash
1 parent e04e78d commit 5b91824

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def user_init(raw_keys):
255255
raw_keys["global_batch_size_to_train_on"],
256256
) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
257257

258-
if getattr(raw_keys, "vae_spatial", -1) == -1 or "vae_spatial" in raw_keys and raw_keys["vae_spatial"] == -1:
258+
if raw_keys.get("vae_spatial", -1) == -1:
259259
total_device = len(jax.devices())
260260
dp = raw_keys.get("ici_data_parallelism", 1) * raw_keys.get("dcn_data_parallelism", 1)
261261
if dp == -1 or dp == 0:

0 commit comments

Comments
 (0)