
Commit 4b2889c

Feat: Add WAN 2.1 step-by-step denoising visualization with automatic video generation
1 parent f28ef83 commit 4b2889c

8 files changed

Lines changed: 752 additions & 2 deletions


VIZ_README.md

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# WAN 2.1 Visualization System

A visualization system for debugging and understanding the WAN 2.1 diffusion denoising process.

## Quick Start
1. **Enable visualization** in the config:

   ```yaml
   # In src/maxdiffusion/configs/base_wan_14b.yml
   visualize_frame_debug: true
   visualization_output_dir: "wan_visualization_output"
   # optional:
   save_tensor_stats: True  # Save tensor statistics to JSON files
   ```
2. **Run inference**:

   ```bash
   export RUN_NAME=wan21-8tpu
   export LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl
   export YOUR_GCS_BUCKET=gs://${USR_NAME}-wan-maxdiffusion

   export OUTPUT_DIR=${YOUR_GCS_BUCKET}/wan/${RUN_NAME}
   export DATASET_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/train/
   export EVAL_DATA_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/eval_timesteps/
   export SAVE_DATASET_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/save/

   export RANDOM=123456789
   export IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly:2025-10-27

   export HUGGINGFACE_HUB_CACHE=/dev/shm

   echo 'Starting WAN inference ...' && \
   python src/maxdiffusion/generate_wan.py \
     src/maxdiffusion/configs/base_wan_14b.yml \
     enable_jax_named_scopes=False \
     attention='flash' \
     weights_dtype=bfloat16 \
     activations_dtype=bfloat16 \
     guidance_scale=5.0 \
     flow_shift=3.0 \
     fps=24 \
     skip_jax_distributed_system=True \
     run_name='test-wan-training-new' \
     output_dir=${OUTPUT_DIR} \
     load_tfrecord_cached=True \
     height=720 \
     width=1280 \
     num_frames=81 \
     num_inference_steps=50 \
     prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
     negative_prompt='low quality, over exposure.' \
     jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
     max_train_steps=20000 \
     enable_profiler=True \
     dataset_save_location=${SAVE_DATASET_DIR} \
     remat_policy='FULL' \
     flash_min_seq_length=0 \
     seed=$RANDOM \
     skip_first_n_steps_for_profiler=3 \
     profiler_steps=3 \
     per_device_batch_size=0.125 \
     allow_split_physical_axes=True \
     ici_data_parallelism=2 \
     ici_fsdp_parallelism=2 \
     ici_tensor_parallelism=2

   echo 'WAN inference completed. Output saved to '${OUTPUT_DIR}
   ```
3. **View outputs** in `wan_visualization_output/frame_debug/`:

   ```
   wan_visualization_output/
   ├── frame_debug/
   │   ├── noise_t999_frame0.png                     # Latent channel 0 at each timestep
   │   ├── current_image_t999_frame0.png             # VAE-decoded image at each timestep
   │   ├── ...                                       # One pair per denoising step
   │   ├── wan_visualization_denoising_process.mp4   # Video: noise → final image
   │   └── wan_visualization_noise_evolution.mp4     # Video: latent evolution
   ```
## Output Examples

The system automatically generates two videos showing the complete denoising process:

```bash
$ ls wan_visualization_output/*.mp4
wan_visualization_output/wan_visualization_denoising_process.mp4
wan_visualization_output/wan_visualization_noise_evolution.mp4
```

**Videos** (4 fps, 0.25s per frame):
- `denoising_process.mp4`: Shows VAE-decoded images evolving from noise to the final result
- `noise_evolution.mp4`: Shows the raw latent space (channel 0) evolving during denoising

**Individual Frames**: 50+ timestamped PNG files showing the step-by-step progression
## What We Built

### 1. **VisualizationMixin Architecture**
- Reusable base class in `src/maxdiffusion/visualization/base_mixin.py`
- Common utilities: file I/O, plotting, statistics
- Used by `WanPipeline(VisualizationMixin)`; a minimal sketch of the interface follows this list
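The mixin's surface area, as exercised by `wan_pipeline.py` in this commit, is small. The sketch below is illustrative only: the method names and config keys (`visualize_frame_debug`, `visualization_output_dir`, `save_tensor_stats`) are taken from the diff, but the method bodies are assumptions, not the actual `base_mixin.py`:

```python
import json
import os

import numpy as np


class VisualizationMixin:
  """Config-driven visualization helpers shared by pipelines (illustrative sketch)."""

  def _should_visualize(self, name: str) -> bool:
    # Zero overhead when disabled: a single config lookup, no tensor work.
    # Only the "frame_debug" feature exists so far, so `name` is unused here.
    return bool(getattr(self.config, "visualize_frame_debug", False))

  def _get_visualization_dir(self, name: str) -> str:
    base = getattr(self.config, "visualization_output_dir", "wan_visualization_output")
    viz_dir = os.path.join(base, name)
    os.makedirs(viz_dir, exist_ok=True)  # Created if it doesn't exist
    return viz_dir

  def _save_tensor_stats(self, tensor, filename: str, name: str) -> None:
    # Gated by the save_tensor_stats config flag; the "_stats.json" suffix is assumed.
    if not getattr(self.config, "save_tensor_stats", False):
      return
    arr = np.asarray(tensor)
    stats = {
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }
    with open(os.path.join(self._get_visualization_dir(name), f"{filename}_stats.json"), "w") as f:
      json.dump(stats, f, indent=2)

  def _save_tensor_as_numpy(self, tensor, filename: str, name: str) -> str:
    path = os.path.join(self._get_visualization_dir(name), f"{filename}.npy")
    np.save(path, np.asarray(tensor))
    return path
```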
### 2. **Step-by-Step Visualization**
- **Modified**: `src/maxdiffusion/pipelines/wan/wan_pipeline.py`
- **Added**: `visualize_frame()` method with automatic calls during inference
- **Captures**: both the latent-space and the VAE-decoded representation at each timestep

### 3. **Automatic Video Generation**
- **Added**: `src/maxdiffusion/visualization/video_utils.py`
- **Uses**: the same `imageio` method as WAN 2.1's video export; see the sketch after this list
- **Creates**: two videos automatically after inference completes
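`video_utils.py` itself is not shown in this commit's excerpt, so the following is a plausible sketch of what `create_visualization_videos` does, based on the behavior documented here (MP4 at 4 fps, imageio with quality=8, the `{type}_t{timestep}_frame{frame_idx}.png` naming, and denoising running from high to low timestep). The real implementation may differ; writing MP4 requires imageio's ffmpeg plugin.

```python
import glob
import os
import re

import imageio


def create_visualization_videos(viz_dir, output_prefix="wan_visualization", fps=4.0):
  """Assemble the per-step PNGs into two MP4s (illustrative sketch)."""
  for kind, suffix in (("current_image", "denoising_process"), ("noise", "noise_evolution")):
    # File naming in this commit: {type}_t{timestep}_frame{frame_idx}.png
    pattern = re.compile(rf"{kind}_t(\d+)_frame\d+\.png$")
    frames = []
    for path in glob.glob(os.path.join(viz_dir, f"{kind}_t*_frame*.png")):
      match = pattern.search(os.path.basename(path))
      if match:
        frames.append((int(match.group(1)), path))
    if not frames:
      continue
    # Denoising runs from high to low timestep (t=999 -> t=57), so sort descending.
    frames.sort(key=lambda item: item[0], reverse=True)
    out_path = os.path.join(viz_dir, f"{output_prefix}_{suffix}.mp4")
    with imageio.get_writer(out_path, fps=fps, quality=8) as writer:
      for _, path in frames:
        writer.append_data(imageio.imread(path))  # All PNGs share fixed dimensions
```

Sorting by descending timestep matters because the loop runs t=999 → t=57, while plain lexicographic filename order would interleave the steps.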
### 4. **Configuration-Driven**
- **Control**: `visualize_frame_debug: true/false` in config files
- **Output**: Configurable directory via `visualization_output_dir`

## Key Features

- **Zero overhead** when disabled (config-controlled)
- **Consistent sizing** (fixed matplotlib dimensions prevent video corruption)
- **Complete timeline** (50 timesteps from t=999 → t=57)
- **Automatic integration** (no separate scripts needed)
- **WAN-compatible** (uses the same video export method as WAN 2.1)
## Technical Details

- **Latent visualization**: Shows channel 0 (10x8 matplotlib figure)
- **Image visualization**: VAE-decoded RGB frames (10x10 matplotlib figure)
- **Video format**: MP4, 4 fps, imageio with quality=8
- **File naming**: `{type}_t{timestep}_frame{frame_idx}.png`
- **Statistics**: JSON files with tensor stats (shape, dtype, mean, std, etc.); see the inspection snippet below
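As a quick check, the saved artifacts can be inspected offline. This is a sketch: the `latent_t*_frame*.npy` names follow the commit's `_save_tensor_as_numpy` calls, while the `_stats.json` suffix is an assumption about how the mixin names its JSON files.

```python
import glob
import json

import numpy as np

viz_dir = "wan_visualization_output/frame_debug"

# Raw latents saved per step: latent_t{timestep}_frame{frame_idx}.npy
for path in sorted(glob.glob(f"{viz_dir}/latent_t*_frame0.npy")):
  latents = np.load(path)
  print(path, latents.shape, f"mean={latents.mean():.4f}", f"std={latents.std():.4f}")

# Tensor statistics saved when save_tensor_stats is enabled
for path in sorted(glob.glob(f"{viz_dir}/*_stats.json")):
  with open(path) as f:
    print(path, json.load(f))
```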
The system provides complete visibility into WAN's denoising process, from initial Gaussian noise to final coherent video frames.

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 1 deletion
```diff
@@ -352,4 +352,11 @@ eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

-enable_ssim: False
+enable_ssim: False
+
+# Enable frame-by-frame debugging visualization
+visualize_frame_debug: True
+
+# Visualization output settings
+visualization_output_dir: "wan_visualization_output" # Will be created if it doesn't exist
+save_tensor_stats: True # Save tensor statistics to JSON files
```

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 179 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 from typing import List, Union, Optional
 from functools import partial
 import numpy as np
+import os
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
@@ -34,6 +35,7 @@
 from transformers import AutoTokenizer, UMT5EncoderModel
 from maxdiffusion.utils.import_utils import is_ftfy_available
 from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs
+from ...visualization import VisualizationMixin, create_visualization_videos
 import html
 import re
 import torch
@@ -172,7 +174,7 @@ def create_sharded_logical_model(model, logical_axis_rules):
   return model


-class WanPipeline:
+class WanPipeline(VisualizationMixin):
   r"""
   Pipeline for text-to-video generation using Wan.

@@ -539,6 +541,157 @@ def prepare_latents(

     return latents

+  def visualize_frame(
+      self,
+      hidden_states: jnp.ndarray,
+      timestep: int,
+      frame_idx: int = 0
+  ) -> None:
+    """
+    Generate visualization for a specific frame at timestep t.
+    Creates two images: the noise representation and the current decoded image.
+
+    Args:
+      hidden_states: Current latent states with shape (batch, channels, frames, height, width)
+      timestep: Current timestep value
+      frame_idx: Which frame to visualize (default: 0 for the first frame)
+    """
+    if not self._should_visualize("frame_debug"):
+      return
+
+    try:
+      max_logging.log(f"Visualizing frame {frame_idx} at timestep {timestep}")
+
+      # Convert to numpy and get statistics
+      hidden_states_np = np.array(hidden_states)
+      batch_size, num_channels, num_frames, height, width = hidden_states_np.shape
+
+      # Ensure frame_idx is valid
+      frame_idx = min(frame_idx, num_frames - 1)
+
+      # Extract the specific frame from the first batch
+      frame_latents = hidden_states_np[0, :, frame_idx, :, :]  # (channels, height, width)
+
+      # Save latent tensor statistics and raw data
+      viz_dir = self._get_visualization_dir("frame_debug")
+      latent_filename = f"latent_t{timestep}_frame{frame_idx}"
+
+      self._save_tensor_stats(frame_latents, latent_filename, "frame_debug")
+      latent_path = self._save_tensor_as_numpy(frame_latents, latent_filename, "frame_debug")
+
+      # 1. Create the noise visualization (latent-space representation)
+      noise_image_path = os.path.join(viz_dir, f"noise_t{timestep}_frame{frame_idx}.png")
+
+      # Show just the first channel as a single image (like how the video is parsed)
+      channel_0 = frame_latents[0]  # (height, width) - first channel only
+
+      # Create a single-image plot for the latent channel with fixed dimensions
+      try:
+        import matplotlib.pyplot as plt
+
+        # Use a fixed figure size and DPI for consistent output dimensions
+        fig, ax = plt.subplots(figsize=(10, 8))
+
+        # Display the latent channel
+        im = ax.imshow(channel_0, cmap='viridis')
+        ax.set_title(f"Latent Channel 0 at t={timestep}, Frame {frame_idx}")
+        ax.axis('off')
+
+        # Add a colorbar with fixed positioning
+        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+        cbar.set_label('Latent Value')
+
+        # Save with fixed dimensions (no bbox_inches='tight', to avoid size variations)
+        plt.savefig(noise_image_path, dpi=150, bbox_inches=None,
+                    facecolor='white', edgecolor='none')
+        plt.close()
+
+        max_logging.log(f"Saved noise visualization to {noise_image_path}")
+
+      except ImportError:
+        max_logging.log("matplotlib not available, skipping noise visualization")
+
+      # 2. Create the current image visualization (VAE-decoded)
+      current_image_path = os.path.join(viz_dir, f"current_image_t{timestep}_frame{frame_idx}.png")
+
+      # Keep a single frame with the proper dimensions for decoding
+      latents_for_decode = hidden_states[:, :, frame_idx:frame_idx + 1, :, :]
+
+      # Apply VAE scaling
+      latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
+      latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
+      scaled_latents = latents_for_decode / latents_std + latents_mean
+      scaled_latents = scaled_latents.astype(jnp.float32)
+
+      # Decode through the VAE
+      with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+        video = self.vae.decode(scaled_latents, self.vae_cache)[0]
+
+      # Process the video
+      video = jnp.transpose(video, (0, 4, 1, 2, 3))
+      video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
+      video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
+      video = self.video_processor.postprocess_video(video, output_type="np")
+
+      # Extract the decoded frame
+      video_np = np.array(video)
+      decoded_frame = video_np[0, 0]  # (height, width, channels) - first batch, first frame
+
+      # Handle different channel formats for display
+      if decoded_frame.shape[-1] == 3:
+        # RGB
+        display_frame = np.clip(decoded_frame, 0, 1)
+        cmap = None
+      else:
+        # Convert to grayscale
+        display_frame = decoded_frame.mean(axis=-1) if len(decoded_frame.shape) == 3 else decoded_frame
+        cmap = 'gray'
+
+      # Create a single-image plot with fixed dimensions
+      try:
+        import matplotlib.pyplot as plt
+
+        # Use a fixed figure size for consistent output dimensions
+        fig, ax = plt.subplots(figsize=(10, 10))
+
+        ax.imshow(display_frame, cmap=cmap)
+        ax.set_title(f"Decoded Frame {frame_idx} at t={timestep}")
+        ax.axis('off')
+
+        # Save with fixed dimensions (no bbox_inches='tight', to avoid size variations)
+        plt.savefig(current_image_path, dpi=150, bbox_inches=None,
+                    facecolor='white', edgecolor='none')
+        plt.close()
+
+        max_logging.log(f"Saved current image to {current_image_path}")
+
+      except ImportError:
+        max_logging.log("matplotlib not available for current image visualization")
+
+      max_logging.log("Frame visualization complete:")
+      max_logging.log(f"  - Noise representation: {noise_image_path}")
+      max_logging.log(f"  - Current decoded image: {current_image_path}")
+      max_logging.log(f"  - Raw latents: {latent_path}")
+
+    except Exception as e:
+      max_logging.log(f"Error in frame visualization: {e}")
+
+  def _create_visualization_videos(self) -> None:
+    """
+    Create the denoising and noise-evolution videos from the visualization output.
+    Uses 4 fps (0.25 seconds per frame) by default.
+    """
+    try:
+      viz_dir = self._get_visualization_dir("frame_debug")
+      if os.path.exists(viz_dir):
+        max_logging.log("Creating visualization videos...")
+        # Use 4 fps for 0.25 seconds per frame (default)
+        create_visualization_videos(viz_dir, output_prefix="wan_visualization", fps=4.0)
+      else:
+        max_logging.log(f"Visualization directory not found: {viz_dir}")
+    except Exception as e:
+      max_logging.log(f"Error creating visualization videos: {e}")
+
   def __call__(
       self,
       prompt: Union[str, List[str]] = None,
@@ -603,6 +756,11 @@ def __call__(
         self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
     )

+    # Visualize the initial noise if enabled, using the actual first timestep
+    if self._should_visualize("frame_debug"):
+      initial_timestep = int(scheduler_state.timesteps[0])
+      self.visualize_frame(latents, timestep=initial_timestep, frame_idx=0)
+
     graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)

     p_run_inference = partial(
@@ -612,6 +770,8 @@ def __call__(
         scheduler=self.scheduler,
         scheduler_state=scheduler_state,
         num_transformer_layers=self.transformer.config.num_layers,
+        should_visualize=self._should_visualize("frame_debug"),
+        visualize_fn=self.visualize_frame if self._should_visualize("frame_debug") else None,
     )

     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -623,6 +783,13 @@ def __call__(
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
     )
+
+    # Visualize the final denoised latents if enabled
+    if self._should_visualize("frame_debug"):
+      # Use the last (lowest) timestep from the scheduler
+      final_timestep = int(scheduler_state.timesteps[-1])
+      self.visualize_frame(latents, timestep=final_timestep, frame_idx=0)
+
     latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
     latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
     latents = latents / latents_std + latents_mean
@@ -635,6 +802,11 @@ def __call__(
     video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
     video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
     video = self.video_processor.postprocess_video(video, output_type="np")
+
+    # Generate visualization videos if frame debugging is enabled
+    if not vae_only and self._should_visualize("frame_debug"):
+      self._create_visualization_videos()
+
     return video


@@ -673,6 +845,8 @@ def run_inference(
     scheduler: FlaxUniPCMultistepScheduler,
     num_transformer_layers: int,
     scheduler_state,
+    should_visualize: bool = False,
+    visualize_fn=None,
 ):
   do_classifier_free_guidance = guidance_scale > 1.0
   if do_classifier_free_guidance:
@@ -695,4 +869,8 @@ def run_inference(
     )

     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+    # Visualize latents at each step if enabled
+    if should_visualize and visualize_fn is not None:
+      visualize_fn(latents, timestep=int(t), frame_idx=0)
+
   return latents
```
