Skip to content

Commit e7926f4

Browse files
committed
Fix for double pipeline loading
1 parent 5f458fd commit e7926f4

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,15 @@ def run(config, pipeline=None, filename_prefix=""):
149149
WanCheckpointer = checkpointer_lib.WanCheckpointer
150150

151151
checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
152-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
153-
154152
if pipeline is None:
155-
pipeline_lib = get_pipeline(model_key)
156-
WanPipeline = pipeline_lib.WanPipeline
157-
pipeline = WanPipeline.from_pretrained(config)
153+
pipeline, _, _ = checkpoint_loader.load_checkpoint()
154+
155+
if pipeline is None:
156+
pipeline_lib = get_pipeline(model_key)
157+
WanPipeline = pipeline_lib.WanPipeline
158+
pipeline = WanPipeline.from_pretrained(config)
159+
else:
160+
max_logging.log("Using provided pipeline for inference.")
158161
s0 = time.perf_counter()
159162

160163
# Using global_batch_size_to_train_on so not to create more config variables

0 commit comments

Comments
 (0)