Skip to content

Commit 2b8549a

Browse files
committed
fixes ssim.
1 parent db4caf0 commit 2b8549a

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@ def print_ssim(pretrained_video_path, posttrained_video_path):
4949
pretrained_video = video_processor.preprocess_video(pretrained_video)
5050
pretrained_video = np.array(pretrained_video)
5151
pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1))
52+
pretrained_video = np.uint8(255 * pretrained_video)
5253

5354
posttrained_video = load_video(posttrained_video_path[0])
5455
posttrained_video = video_processor.preprocess_video(posttrained_video)
5556
posttrained_video = np.array(posttrained_video)
5657
posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1))
58+
posttrained_video = np.uint8(255 * posttrained_video)
59+
5760
ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255)
5861

5962
max_logging.log(f"SSIM score after training is {ssim_compare}")

0 commit comments

Comments
 (0)