Skip to content

Commit b7c8ba6

Browse files
Add WAN pipeline with end-to-end video generation. Correctness is still not verified.
1 parent 0731a49 commit b7c8ba6

2 files changed

Lines changed: 97 additions & 80 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414

1515
from typing import Sequence
16+
import time
1617
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
1718
from maxdiffusion import pyconfig
1819
from absl import app
20+
from maxdiffusion.utils import export_to_video
1921

2022
def run(config):
  """Generate a video with the WAN pipeline and report timing.

  The pipeline is invoked twice with identical arguments: the first call
  includes JIT compilation cost (printed as "compile time"), the second
  measures steady-state generation (printed as "generation time").  The
  second result is written to ``jax_output.mp4``.

  Args:
    config: parsed maxdiffusion config; must provide prompt, negative_prompt,
      height, width, num_frames, num_inference_steps and guidance_scale.
  """
  pipeline = WanPipeline.from_pretrained(config)

  def _generate():
    # Single text-to-video sampling pass, driven entirely by `config`.
    # Factored out so the warm-up call and the timed call cannot drift apart.
    return pipeline(
        prompt=config.prompt,
        negative_prompt=config.negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
    )

  # First call pays the tracing/compilation cost.
  s0 = time.perf_counter()
  video = _generate()
  print("compile time: ", (time.perf_counter() - s0))

  # Second call measures generation speed with compiled functions cached.
  s0 = time.perf_counter()
  video = _generate()
  print("generation time: ", (time.perf_counter() - s0))

  export_to_video(video[0], "jax_output.mp4", fps=16)
3249

3350
def main(argv: Sequence[str]) -> None:
3451
pyconfig.initialize(argv)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 79 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from ...models.wan.transformers.transformer_wan import WanModel
2727
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
2828
from maxdiffusion.video_processor import VideoProcessor
29-
from ...utils import export_to_video
3029
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
3130
from transformers import AutoTokenizer, UMT5EncoderModel
3231
import ftfy
@@ -314,75 +313,77 @@ def __call__(
314313
max_sequence_length: int = 512,
315314
latents: jax.Array = None,
316315
prompt_embeds: jax.Array = None,
317-
negative_prompt_embeds: jax.Array = None
316+
negative_prompt_embeds: jax.Array = None,
317+
vae_only: bool = False
318318
):
319-
if num_frames % self.vae_scale_factor_temporal != 1:
320-
max_logging.log(
321-
f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
319+
if not vae_only:
320+
if num_frames % self.vae_scale_factor_temporal != 1:
321+
max_logging.log(
322+
f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
323+
)
324+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
325+
num_frames = max(num_frames, 1)
326+
327+
# 2. Define call parameters
328+
if prompt is not None and isinstance(prompt, str):
329+
batch_size = 1
330+
elif prompt is not None and isinstance(prompt, list):
331+
batch_size = len(prompt)
332+
333+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
334+
prompt=prompt,
335+
negative_prompt=negative_prompt,
336+
max_sequence_length=max_sequence_length,
337+
prompt_embeds=prompt_embeds,
338+
negative_prompt_embeds=negative_prompt_embeds
322339
)
323-
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
324-
num_frames = max(num_frames, 1)
325-
326-
# 2. Define call parameters
327-
if prompt is not None and isinstance(prompt, str):
328-
batch_size = 1
329-
elif prompt is not None and isinstance(prompt, list):
330-
batch_size = len(prompt)
331-
332-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
333-
prompt=prompt,
334-
negative_prompt=negative_prompt,
335-
max_sequence_length=max_sequence_length,
336-
prompt_embeds=prompt_embeds,
337-
negative_prompt_embeds=negative_prompt_embeds
338-
)
339340

340-
num_channel_latents = self.transformer.config.in_channels
341-
if latents is None:
342-
latents = self.prepare_latents(
343-
batch_size=batch_size,
344-
vae_scale_factor_temporal=self.vae_scale_factor_temporal,
345-
vae_scale_factor_spatial=self.vae_scale_factor_spatial,
346-
height=height,
347-
width=width,
348-
num_frames=num_frames,
349-
num_channels_latents=num_channel_latents
341+
num_channel_latents = self.transformer.config.in_channels
342+
if latents is None:
343+
latents = self.prepare_latents(
344+
batch_size=batch_size,
345+
vae_scale_factor_temporal=self.vae_scale_factor_temporal,
346+
vae_scale_factor_spatial=self.vae_scale_factor_spatial,
347+
height=height,
348+
width=width,
349+
num_frames=num_frames,
350+
num_channels_latents=num_channel_latents
351+
)
352+
353+
prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
354+
negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
355+
356+
latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
357+
prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
358+
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
359+
360+
scheduler_state = self.scheduler.set_timesteps(
361+
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
350362
)
351363

352-
prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
353-
negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
354-
355-
latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
356-
prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
357-
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
358-
359-
scheduler_state = self.scheduler.set_timesteps(
360-
self.scheduler_state, num_inference_steps=self.config.num_inference_steps, shape=latents.shape
361-
)
362-
363-
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
364+
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
364365

365-
p_run_inference = partial(
366-
run_inference,
367-
guidance_scale=self.config.guidance_scale,
368-
num_inference_steps=self.config.num_inference_steps,
369-
scheduler=self.scheduler,
370-
scheduler_state=scheduler_state
371-
)
372-
with self.mesh:
373-
latents = p_run_inference(
374-
graphdef=graphdef,
375-
sharded_state=state,
376-
rest_of_state=rest_of_state,
377-
latents=latents,
378-
prompt_embeds=prompt_embeds,
379-
negative_prompt_embeds=negative_prompt_embeds
366+
p_run_inference = partial(
367+
run_inference,
368+
guidance_scale=guidance_scale,
369+
num_inference_steps=num_inference_steps,
370+
scheduler=self.scheduler,
371+
scheduler_state=scheduler_state
380372
)
381-
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
382-
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
383-
latents = latents / latents_std + latents_mean
384-
385-
latents = latents.astype(self.config.weights_dtype)
373+
with self.mesh:
374+
latents = p_run_inference(
375+
graphdef=graphdef,
376+
sharded_state=state,
377+
rest_of_state=rest_of_state,
378+
latents=latents,
379+
prompt_embeds=prompt_embeds,
380+
negative_prompt_embeds=negative_prompt_embeds
381+
)
382+
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
383+
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
384+
latents = latents / latents_std + latents_mean
385+
386+
latents = latents.astype(self.config.weights_dtype)
386387

387388
jitted_decode = jax.jit(
388389
partial(
@@ -396,9 +397,18 @@ def __call__(
396397
video = jnp.transpose(video, (0, 4, 1, 2, 3))
397398
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
398399
video = self.video_processor.postprocess_video(video, output_type="np")
399-
export_to_video(video[0], "jax_output.mp4", fps=24)
400+
return video
401+
402+
403+
@jax.jit
def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds):
  """Jitted single forward pass of the WAN transformer.

  Rebuilds the NNX module from its split (graphdef, params, rest) so the
  function is a pure pytree-in/pytree-out callable that jax.jit can trace,
  then returns the first element of the model output (the noise prediction).
  """
  model = nnx.merge(graphdef, sharded_state, rest_of_state)
  outputs = model(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
  return outputs[0]
400411

401-
402412
#@partial(jax.jit, static_argnums=(6, 7, 8))
def run_inference(
    graphdef,
    sharded_state,
    rest_of_state,
    latents,
    prompt_embeds,
    negative_prompt_embeds,
    guidance_scale: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
    scheduler_state):
  """Denoising loop: iteratively refine `latents` over the scheduler timesteps.

  Each step runs a conditional transformer pass and, when guidance_scale > 1,
  an unconditional pass whose result is combined via classifier-free guidance
  before the scheduler advances the latents.

  NOTE(review): parameter order reconstructed from the call site and the
  static_argnums hint above — confirm against the full file.
  """
  use_cfg = guidance_scale > 1.0
  for step in range(num_inference_steps):
    # Re-read timesteps from the current scheduler state each iteration,
    # since scheduler.step returns an updated state.
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents.shape[0])

    noise_pred = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds)
    if use_cfg:
      noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds)
      noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
  return latents

0 commit comments

Comments
 (0)