1515from typing import List , Union , Optional
1616from functools import partial
1717import numpy as np
18+ import os
1819import jax
1920import jax .numpy as jnp
2021from jax .sharding import Mesh , NamedSharding , PartitionSpec as P
3435from transformers import AutoTokenizer , UMT5EncoderModel
3536from maxdiffusion .utils .import_utils import is_ftfy_available
3637from maxdiffusion .maxdiffusion_utils import get_dummy_wan_inputs
38+ from ...visualization import VisualizationMixin , create_visualization_videos
3739import html
3840import re
3941import torch
@@ -172,7 +174,7 @@ def create_sharded_logical_model(model, logical_axis_rules):
172174 return model
173175
174176
175- class WanPipeline :
177+ class WanPipeline ( VisualizationMixin ) :
176178 r"""
177179 Pipeline for text-to-video generation using Wan.
178180
@@ -539,6 +541,157 @@ def prepare_latents(
539541
540542 return latents
541543
544+ def visualize_frame (
545+ self ,
546+ hidden_states : jnp .ndarray ,
547+ timestep : int ,
548+ frame_idx : int = 0
549+ ) -> None :
550+ """
551+ Generate visualization for a specific frame at timestep t.
552+ Creates two images: noise representation and current decoded image.
553+
554+ Args:
555+ hidden_states: Current latent states with shape (batch, channels, frames, height, width)
556+ timestep: Current timestep value
557+ frame_idx: Which frame to visualize (default: 0 for first frame)
558+ """
559+ if not self ._should_visualize ("frame_debug" ):
560+ return
561+
562+ try :
563+ max_logging .log (f"Visualizing frame { frame_idx } at timestep { timestep } " )
564+
565+ # Convert to numpy and get statistics
566+ hidden_states_np = np .array (hidden_states )
567+ batch_size , num_channels , num_frames , height , width = hidden_states_np .shape
568+
569+ # Ensure frame_idx is valid
570+ frame_idx = min (frame_idx , num_frames - 1 )
571+
572+ # Extract the specific frame from first batch
573+ frame_latents = hidden_states_np [0 , :, frame_idx , :, :] # (channels, height, width)
574+
575+ # Save latent tensor statistics and raw data
576+ viz_dir = self ._get_visualization_dir ("frame_debug" )
577+ latent_filename = f"latent_t{ timestep } _frame{ frame_idx } "
578+
579+ self ._save_tensor_stats (frame_latents , latent_filename , "frame_debug" )
580+ latent_path = self ._save_tensor_as_numpy (frame_latents , latent_filename , "frame_debug" )
581+
582+ # 1. Create noise visualization (latent space representation)
583+ noise_image_path = os .path .join (viz_dir , f"noise_t{ timestep } _frame{ frame_idx } .png" )
584+
585+ # Show just the first channel as a single image (like how video is parsed)
586+ channel_0 = frame_latents [0 ] # (height, width) - first channel only
587+
588+ # Create single image plot for the latent channel with fixed dimensions
589+ try :
590+ import matplotlib .pyplot as plt
591+
592+ # Use fixed figure size and DPI for consistent output dimensions
593+ fig , ax = plt .subplots (figsize = (10 , 8 ))
594+
595+ # Display the latent channel
596+ im = ax .imshow (channel_0 , cmap = 'viridis' )
597+ ax .set_title (f"Latent Channel 0 at t={ timestep } , Frame { frame_idx } " )
598+ ax .axis ('off' )
599+
600+ # Add colorbar with fixed positioning
601+ cbar = plt .colorbar (im , ax = ax , fraction = 0.046 , pad = 0.04 )
602+ cbar .set_label ('Latent Value' )
603+
604+ # Save with fixed dimensions (no bbox_inches='tight' to avoid size variations)
605+ plt .savefig (noise_image_path , dpi = 150 , bbox_inches = None ,
606+ facecolor = 'white' , edgecolor = 'none' )
607+ plt .close ()
608+
609+ max_logging .log (f"Saved noise visualization to { noise_image_path } " )
610+
611+ except ImportError :
612+ max_logging .log ("matplotlib not available, skipping noise visualization" )
613+
614+ # 2. Create current image visualization (VAE decoded)
615+ current_image_path = os .path .join (viz_dir , f"current_image_t{ timestep } _frame{ frame_idx } .png" )
616+
617+ # Decode through VAE
618+ latents_for_decode = hidden_states [:, :, frame_idx :frame_idx + 1 , :, :] # Keep single frame with proper dims
619+
620+ # Apply VAE scaling
621+ latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
622+ latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
623+ scaled_latents = latents_for_decode / latents_std + latents_mean
624+ scaled_latents = scaled_latents .astype (jnp .float32 )
625+
626+ # Decode through VAE
627+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
628+ video = self .vae .decode (scaled_latents , self .vae_cache )[0 ]
629+
630+ # Process video
631+ video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
632+ video = jax .experimental .multihost_utils .process_allgather (video , tiled = True )
633+ video = torch .from_numpy (np .array (video .astype (dtype = jnp .float32 ))).to (dtype = torch .bfloat16 )
634+ video = self .video_processor .postprocess_video (video , output_type = "np" )
635+
636+ # Extract the decoded frame
637+ video_np = np .array (video )
638+ decoded_frame = video_np [0 , 0 ] # (height, width, channels) - first batch, first frame
639+
640+ # Handle different channel formats for display
641+ if decoded_frame .shape [- 1 ] == 3 :
642+ # RGB
643+ display_frame = np .clip (decoded_frame , 0 , 1 )
644+ cmap = None
645+ else :
646+ # Convert to grayscale
647+ display_frame = decoded_frame .mean (axis = - 1 ) if len (decoded_frame .shape ) == 3 else decoded_frame
648+ cmap = 'gray'
649+
650+ # Create single image plot with fixed dimensions
651+ try :
652+ import matplotlib .pyplot as plt
653+
654+ # Use fixed figure size for consistent output dimensions
655+ fig , ax = plt .subplots (figsize = (10 , 10 ))
656+
657+ ax .imshow (display_frame , cmap = cmap )
658+ ax .set_title (f"Decoded Frame { frame_idx } at t={ timestep } " )
659+ ax .axis ('off' )
660+
661+ # Save with fixed dimensions (no bbox_inches='tight' to avoid size variations)
662+ plt .savefig (current_image_path , dpi = 150 , bbox_inches = None ,
663+ facecolor = 'white' , edgecolor = 'none' )
664+ plt .close ()
665+
666+ max_logging .log (f"Saved current image to { current_image_path } " )
667+
668+ except ImportError :
669+ max_logging .log ("matplotlib not available for current image visualization" )
670+
671+ max_logging .log ("Frame visualization complete:" )
672+ max_logging .log (f" - Noise representation: { noise_image_path } " )
673+ max_logging .log (f" - Current decoded image: { current_image_path } " )
674+ max_logging .log (f" - Raw latents: { latent_path } " )
675+
676+ except Exception as e :
677+ max_logging .log (f"Error in frame visualization: { e } " )
678+
679+ def _create_visualization_videos (self ) -> None :
680+ """
681+ Create denoising and noise evolution videos from the visualization output.
682+ Uses 4 fps (0.25 seconds per frame) by default.
683+ """
684+ try :
685+ viz_dir = self ._get_visualization_dir ("frame_debug" )
686+ if os .path .exists (viz_dir ):
687+ max_logging .log ("Creating visualization videos..." )
688+ # Use 4 fps for 0.25 seconds per frame (default)
689+ create_visualization_videos (viz_dir , output_prefix = "wan_visualization" , fps = 4.0 )
690+ else :
691+ max_logging .log (f"Visualization directory not found: { viz_dir } " )
692+ except Exception as e :
693+ max_logging .log (f"Error creating visualization videos: { e } " )
694+
542695 def __call__ (
543696 self ,
544697 prompt : Union [str , List [str ]] = None ,
@@ -603,6 +756,11 @@ def __call__(
603756 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
604757 )
605758
759+ # Visualize initial noise if enabled - using actual first timestep
760+ if self ._should_visualize ("frame_debug" ):
761+ initial_timestep = int (scheduler_state .timesteps [0 ])
762+ self .visualize_frame (latents , timestep = initial_timestep , frame_idx = 0 )
763+
606764 graphdef , state , rest_of_state = nnx .split (self .transformer , nnx .Param , ...)
607765
608766 p_run_inference = partial (
@@ -612,6 +770,8 @@ def __call__(
612770 scheduler = self .scheduler ,
613771 scheduler_state = scheduler_state ,
614772 num_transformer_layers = self .transformer .config .num_layers ,
773+ should_visualize = self ._should_visualize ("frame_debug" ),
774+ visualize_fn = self .visualize_frame if self ._should_visualize ("frame_debug" ) else None ,
615775 )
616776
617777 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -623,6 +783,13 @@ def __call__(
623783 prompt_embeds = prompt_embeds ,
624784 negative_prompt_embeds = negative_prompt_embeds ,
625785 )
786+
787+ # Visualize final denoised latents if enabled
788+ if self ._should_visualize ("frame_debug" ):
789+ # Use the last (lowest) timestep from the scheduler
790+ final_timestep = int (scheduler_state .timesteps [- 1 ])
791+ self .visualize_frame (latents , timestep = final_timestep , frame_idx = 0 )
792+
626793 latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
627794 latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
628795 latents = latents / latents_std + latents_mean
@@ -635,6 +802,11 @@ def __call__(
635802 video = jax .experimental .multihost_utils .process_allgather (video , tiled = True )
636803 video = torch .from_numpy (np .array (video .astype (dtype = jnp .float32 ))).to (dtype = torch .bfloat16 )
637804 video = self .video_processor .postprocess_video (video , output_type = "np" )
805+
806+ # Generate visualization videos if frame debugging is enabled
807+ if not vae_only and self ._should_visualize ("frame_debug" ):
808+ self ._create_visualization_videos ()
809+
638810 return video
639811
640812
@@ -673,6 +845,8 @@ def run_inference(
673845 scheduler : FlaxUniPCMultistepScheduler ,
674846 num_transformer_layers : int ,
675847 scheduler_state ,
848+ should_visualize : bool = False ,
849+ visualize_fn = None ,
676850):
677851 do_classifier_free_guidance = guidance_scale > 1.0
678852 if do_classifier_free_guidance :
@@ -695,4 +869,8 @@ def run_inference(
695869 )
696870
697871 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
872+ # Visualize latents at each step if enabled
873+ if should_visualize and visualize_fn is not None :
874+ visualize_fn (latents , timestep = int (t ), frame_idx = 0 )
875+
698876 return latents
0 commit comments