Skip to content

Commit a583a7d

Browse files
committed
Fix to avoid loading the checkpoint multiple times
1 parent fa1871e commit a583a7d

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,21 @@ def run(config, pipeline=None, filename_prefix=""):
145145
if jax.process_index() == 0 and writer:
146146
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
147147

148-
checkpointer_lib = get_checkpointer(model_key)
149-
WanCheckpointer = checkpointer_lib.WanCheckpointer
148+
if pipeline is None:
149+
checkpointer_lib = get_checkpointer(model_key)
150+
WanCheckpointer = checkpointer_lib.WanCheckpointer
150151

151-
checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
152-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
152+
checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
153+
pipeline, _, _ = checkpoint_loader.load_checkpoint()
153154

154-
if pipeline is None:
155-
pipeline_lib = get_pipeline(model_key)
156-
WanPipeline = pipeline_lib.WanPipeline
157-
pipeline = WanPipeline.from_pretrained(config)
155+
if pipeline is None:
156+
pipeline_lib = get_pipeline(model_key)
157+
WanPipeline = pipeline_lib.WanPipeline
158+
pipeline = WanPipeline.from_pretrained(config)
159+
else:
160+
max_logging.log("Checkpoint loaded successfully.")
161+
else:
162+
max_logging.log("Using provided pipeline in generate_wan.run.")
158163
s0 = time.perf_counter()
159164

160165
# Using global_batch_size_to_train_on so not to create more config variables

0 commit comments

Comments
 (0)