Skip to content

Commit 6ec20ea

Browse files
authored
Rectified Flow Scheduler Test (#220)
* added rf scheduler test * rectified flow scheduler test added * removed safetensors downloading * replaced print statements with max_logging
1 parent de60c6c commit 6ec20ea

7 files changed

Lines changed: 84 additions & 0 deletions

File tree

128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import jax.numpy as jnp
2+
from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler
3+
import os
4+
from maxdiffusion import max_logging
5+
import torch
6+
import unittest
7+
from absl.testing import absltest
8+
import numpy as np
9+
10+
11+
12+
class rfTest(unittest.TestCase):
13+
14+
def test_rf_steps(self):
15+
# --- Simulation Parameters ---
16+
latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width)
17+
inference_steps_count = 5 # Number of steps for the denoising process
18+
19+
# --- Run the Simulation ---
20+
max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---")
21+
22+
seed = 42
23+
device = 'cpu'
24+
max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}")
25+
26+
generator = torch.Generator(device=device).manual_seed(seed)
27+
28+
# 1. Instantiate the scheduler
29+
config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'}
30+
flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config)
31+
32+
# 2. Create and set initial state for the scheduler
33+
flax_state = flax_scheduler.create_state()
34+
flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape)
35+
max_logging.log("\nScheduler initialized.")
36+
max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}")
37+
38+
# 3. Prepare the initial noisy latent sample
39+
# In a real scenario, this would typically be pure random noise (e.g., N(0,1))
40+
# For simulation, we'll generate it.
41+
42+
sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
43+
max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}")
44+
45+
# 4. Simulate the denoising loop
46+
max_logging.log("\nStarting denoising loop:")
47+
for i, t in enumerate(flax_state.timesteps):
48+
max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}")
49+
50+
# Simulate model_output (e.g., noise prediction from a UNet)
51+
model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
52+
53+
# Call the scheduler's step function
54+
scheduler_output = flax_scheduler.step(
55+
state=flax_state,
56+
model_output=model_output,
57+
timestep=t, # Pass the current timestep from the scheduler's sequence
58+
sample=sample,
59+
return_dict=True # Return a SchedulerOutput dataclass
60+
)
61+
62+
sample = scheduler_output.prev_sample # Update the sample for the next step
63+
flax_state = scheduler_output.state # Update the state for the next step
64+
65+
# Compare with pytorch implementation
66+
base_dir = os.path.dirname(__file__)
67+
ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref")
68+
ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy")
69+
if os.path.exists(ref_filename):
70+
pt_sample = np.load(ref_filename)
71+
torch.testing.assert_close(np.array(sample), pt_sample)
72+
else:
73+
max_logging.log(f"Warning: Reference file not found: {ref_filename}")
74+
75+
76+
max_logging.log("\nDenoising loop completed.")
77+
max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}")
78+
max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}")
79+
80+
max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.")
81+
82+
83+
if __name__ == "__main__":
84+
absltest.main()

0 commit comments

Comments
 (0)