Skip to content

Commit d6adbd6

Browse files
author
Juan Acevedo
committed
Generates an SSIM score between the pretrained and trained models.
1 parent 31dcb6c commit d6adbd6

2 files changed

Lines changed: 40 additions & 12 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import jax
1717
import time
1818
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
19-
from maxdiffusion import pyconfig, max_logging
19+
from maxdiffusion import pyconfig, max_logging, max_utils
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
2222

@@ -59,8 +59,12 @@ def run(config, pipeline=None, filename_prefix=""):
5959
)
6060

6161
print("compile time: ", (time.perf_counter() - s0))
62+
saved_video_path = []
6263
for i in range(len(videos)):
63-
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
64+
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
65+
export_to_video(videos[i], video_path, fps=config.fps)
66+
saved_video_path.append(video_path)
67+
6468
s0 = time.perf_counter()
6569
videos = pipeline(
6670
prompt=prompt,
@@ -74,12 +78,11 @@ def run(config, pipeline=None, filename_prefix=""):
7478
slg_start=slg_start,
7579
slg_end=slg_end,
7680
)
77-
print("generation time: ", (time.perf_counter() - s0))
78-
for i in range(len(videos)):
79-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
81+
print("compile time: ", (time.perf_counter() - s0))
8082

8183
s0 = time.perf_counter()
82-
with jax.profiler.trace("/tmp/trace/"):
84+
if config.enable_profiler:
85+
max_utils.activate_profiler(config)
8386
videos = pipeline(
8487
prompt=prompt,
8588
negative_prompt=negative_prompt,
@@ -92,7 +95,9 @@ def run(config, pipeline=None, filename_prefix=""):
9295
slg_start=slg_start,
9396
slg_end=slg_end,
9497
)
95-
print("generation time: ", (time.perf_counter() - s0))
98+
max_utils.deactivate_profiler(config)
99+
print("generation time: ", (time.perf_counter() - s0))
100+
return saved_video_path
96101

97102

98103
def main(argv: Sequence[str]) -> None:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,32 @@
3131
from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
3232
from maxdiffusion.generate_wan import run as generate_wan
3333
from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue)
34+
from ..video_processor import VideoProcessor
35+
from ..utils import load_video
36+
from skimage.metrics import structural_similarity as ssim
3437

3538

3639
def generate_sample(config, pipeline, filename_prefix):
3740
"""
3841
Generates a video to validate that training did not corrupt the model
3942
"""
40-
generate_wan(config, pipeline, filename_prefix)
43+
return generate_wan(config, pipeline, filename_prefix)
44+
45+
46+
def print_ssim(pretrained_video_path, posttrained_video_path):
47+
video_processor = VideoProcessor()
48+
pretrained_video = load_video(pretrained_video_path[0])
49+
pretrained_video = video_processor.preprocess_video(pretrained_video)
50+
pretrained_video = np.array(pretrained_video)
51+
pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1))
52+
53+
posttrained_video = load_video(posttrained_video_path[0])
54+
posttrained_video = video_processor.preprocess_video(posttrained_video)
55+
posttrained_video = np.array(posttrained_video)
56+
posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1))
57+
ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255)
58+
59+
max_logging.log(f"SSIM score after training is {ssim_compare}")
4160

4261

4362
class WanTrainer(WanCheckpointer):
@@ -105,7 +124,7 @@ def start_training(self):
105124
# del pipeline.vae
106125

107126
# Generate a sample before training to compare against generated sample after training.
108-
generate_sample(self.config, pipeline, filename_prefix="pre-training-")
127+
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
109128
mesh = pipeline.mesh
110129
data_iterator = self.load_dataset(mesh)
111130

@@ -115,7 +134,12 @@ def start_training(self):
115134
pipeline.scheduler_state = scheduler_state
116135

117136
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
118-
self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
137+
138+
# Returns pipeline with trained transformer state
139+
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
140+
141+
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
142+
print_ssim(pretrained_video_path, posttrained_video_path)
119143

120144
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
121145

@@ -189,8 +213,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
189213
# load new state for trained transformer
190214
graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
191215
pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)
192-
193-
generate_sample(self.config, pipeline, filename_prefix="post-training-")
216+
return pipeline
194217

195218

196219
def train_step(state, graphdef, scheduler_state, data, rng, scheduler):

0 commit comments

Comments
 (0)