Skip to content

Commit 64a404d

Browse files
committed
add generate video during training
1 parent 39b227b commit 64a404d

2 files changed

Lines changed: 28 additions & 1 deletion

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,34 @@
1919
from maxdiffusion import pyconfig, max_logging, max_utils
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
22+
import os
2223

2324
jax.config.update("jax_use_shardy_partitioner", True)
2425

26+
def inference_generate_video(config, pipeline, filename_prefix=""):
27+
s0 = time.perf_counter()
28+
prompt = [config.prompt] * config.global_batch_size_to_train_on
29+
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
30+
31+
max_logging.log(
32+
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
33+
)
34+
35+
videos = pipeline(
36+
prompt=prompt,
37+
negative_prompt=negative_prompt,
38+
height=config.height,
39+
width=config.width,
40+
num_frames=config.num_frames,
41+
num_inference_steps=config.num_inference_steps,
42+
guidance_scale=config.guidance_scale,
43+
)
44+
45+
print(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
46+
for i in range(len(videos)):
47+
video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
48+
export_to_video(videos[i], video_path, fps=config.fps)
49+
return
2550

2651
def run(config, pipeline=None, filename_prefix=""):
2752
print("seed: ", config.seed)
@@ -54,7 +79,7 @@ def run(config, pipeline=None, filename_prefix=""):
5479
print("compile time: ", (time.perf_counter() - s0))
5580
saved_video_path = []
5681
for i in range(len(videos)):
57-
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
82+
video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
5883
export_to_video(videos[i], video_path, fps=config.fps)
5984
saved_video_path.append(video_path)
6085

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from skimage.metrics import structural_similarity as ssim
3838
from flax.training import train_state
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
40+
from maxdiffusion.generate_wan import inference_generate_video
4041

4142

4243
class TrainState(train_state.TrainState):
@@ -251,6 +252,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
251252
if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
252253
# Re-create the iterator each time you start evaluation to reset it
253254
# This assumes your data loading logic can be called to get a fresh iterator.
255+
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
254256
eval_data_iterator = self.load_dataset(mesh, is_training=False)
255257
eval_rng = jax.random.key(self.config.seed + step)
256258
eval_metrics = []

0 commit comments

Comments
 (0)