Skip to content

Commit 6973222

Browse files
Initial Wan text-to-video (txt2vid) pipeline. Not currently working.
1 parent 716598b commit 6973222

3 files changed

Lines changed: 170 additions & 4 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,14 @@ skip_first_n_steps_for_profiler: 5
210210
profiler_steps: 10
211211

212212
# Generation parameters
213-
prompt: "A magical castle in the middle of a forest, artistic drawing"
214-
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
215-
negative_prompt: "purple, red"
213+
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
214+
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
215+
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
216216
do_classifier_free_guidance: True
217-
guidance_scale: 3.5
217+
height: 720
218+
width: 1280
219+
num_frames: 81
220+
guidance_scale: 5.0
218221
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
219222
guidance_rescale: 0.0
220223
num_inference_steps: 30

src/maxdiffusion/generate_wan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
def run(config):
  """Generate a video with the Wan text-to-video pipeline described by `config`.

  Builds the pipeline from pretrained weights, invokes it with the generation
  parameters carried by the config, and returns the pipeline's output.

  Args:
    config: hyperparameter object exposing prompt, negative_prompt, height,
      width, num_frames, num_inference_steps and guidance_scale.

  Returns:
    Whatever the pipeline call produces (the denoising result). BUGFIX: the
    original discarded this value, making the run unobservable to callers.
  """
  pipeline = WanPipeline.from_pretrained(config)

  # Forward the generation parameters from the config; classifier-free
  # guidance is applied inside the pipeline when guidance_scale > 1.0.
  output = pipeline(
      prompt=config.prompt,
      negative_prompt=config.negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
  )
  return output
32+
2333
def main(argv: Sequence[str]) -> None:
  """CLI entry point: build the global config from argv, then run generation."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# limitations under the License.
1414

1515
from typing import List, Union, Optional
16+
from functools import partial
1617
import numpy as np
1718
import jax
19+
import jax.numpy as jnp
1820
from jax.sharding import Mesh, PositionalSharding
1921
from flax import nnx
2022
from ...pyconfig import HyperParameters
23+
from ... import max_logging
2124
from ... import max_utils
2225
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
2326
from ...models.wan.transformers.transformer_wan import WanModel
@@ -219,6 +222,156 @@ def encode_prompt(
219222
num_videos_per_prompt: int = 1,
220223
max_sequence_length: int = 226,
221224
):
225+
prompt = [prompt] if isinstance(prompt, str) else prompt
226+
batch_size = len(prompt)
227+
prompt_embeds = self._get_t5_prompt_embeds(
228+
prompt=prompt,
229+
num_videos_per_prompt=num_videos_per_prompt,
230+
max_sequence_length=max_sequence_length,
231+
)
232+
233+
negative_prompt = negative_prompt or ""
234+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
235+
negative_prompt_embeds = self._get_t5_prompt_embeds(
236+
prompt=negative_prompt,
237+
num_videos_per_prompt=num_videos_per_prompt,
238+
max_sequence_length=max_sequence_length,
239+
)
240+
241+
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
242+
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
243+
return prompt_embeds, negative_prompt_embeds
244+
245+
def prepare_latents(
    self,
    batch_size: int,
    vae_scale_factor_temporal: int,
    vae_scale_factor_spatial: int,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_channels_latents: int = 16,
):
  """Draw the initial Gaussian noise latents for sampling.

  Returns an array of shape
  (batch_size, latent_frames, height // spatial_scale, width // spatial_scale,
  num_channels_latents), seeded from self.config.seed and cast to
  self.config.weights_dtype.
  """
  # Temporal compression: every vae_scale_factor_temporal pixel frames map to
  # one latent frame, with the first frame kept as-is.
  latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
  latent_height = int(height) // vae_scale_factor_spatial
  latent_width = int(width) // vae_scale_factor_spatial

  key = jax.random.key(self.config.seed)
  return jax.random.normal(
      key,
      shape=(batch_size, latent_frames, latent_height, latent_width, num_channels_latents),
      dtype=self.config.weights_dtype,
  )
267+
268+
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    num_videos_per_prompt: Optional[int] = 1,
    max_sequence_length: int = 512
):
  """Run the Wan text-to-video denoising loop.

  Args:
    prompt: text prompt or list of prompts to condition on.
    negative_prompt: prompt(s) used for classifier-free guidance.
    height, width: output resolution in pixels; divided by the VAE spatial
      scale factor to get the latent resolution.
    num_frames: requested frame count; rounded so num_frames - 1 is a
      multiple of the VAE temporal scale factor.
    num_inference_steps: number of scheduler steps.
    guidance_scale: classifier-free guidance weight (> 1.0 enables CFG).
    num_videos_per_prompt: videos generated per prompt.
    max_sequence_length: max T5 token length for prompt encoding.

  Returns:
    The denoised latents from the transformer loop. NOTE(review): VAE
    decoding to pixels is not performed here yet.

  Raises:
    ValueError: if `prompt` is neither a string nor a list.
  """
  # 1. num_frames must satisfy num_frames % temporal_scale == 1.
  if num_frames % self.vae_scale_factor_temporal != 1:
    max_logging.log(
        f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
    )
    num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
  num_frames = max(num_frames, 1)

  # 2. Define call parameters
  if prompt is not None and isinstance(prompt, str):
    batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
  else:
    # BUGFIX: previously fell through and hit an unbound `batch_size`.
    raise ValueError("`prompt` must be a string or a list of strings.")

  # BUGFIX: forward num_videos_per_prompt (it was accepted but ignored).
  prompt_embeds, negative_prompt_embeds = self.encode_prompt(
      prompt=prompt,
      negative_prompt=negative_prompt,
      num_videos_per_prompt=num_videos_per_prompt,
      max_sequence_length=max_sequence_length
  )

  num_channel_latents = self.transformer.config.in_channels
  latents = self.prepare_latents(
      batch_size=batch_size,
      vae_scale_factor_temporal=self.vae_scale_factor_temporal,
      vae_scale_factor_spatial=self.vae_scale_factor_spatial,
      height=height,
      width=width,
      num_frames=num_frames,
      num_channels_latents=num_channel_latents
  )

  # NOTE(review): tiling by latents.shape[0] squares the embedding batch when
  # batch_size > 1 — verify this is intended before enabling multi-prompt runs.
  prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
  negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)

  # Replicate all inputs across devices before entering the jitted loop.
  sharding = PositionalSharding(self.devices_array).replicate()
  latents = jax.device_put(latents, sharding)
  prompt_embeds = jax.device_put(prompt_embeds, sharding)
  negative_prompt_embeds = jax.device_put(negative_prompt_embeds, sharding)

  # BUGFIX: use the call arguments (num_inference_steps, guidance_scale)
  # instead of self.config.* so the parameters of __call__ take effect.
  scheduler_state = self.scheduler.set_timesteps(
      self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
  )

  graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)

  p_run_inference = partial(
      run_inference,
      guidance_scale=guidance_scale,
      num_inference_steps=num_inference_steps,
      scheduler=self.scheduler,
      scheduler_state=scheduler_state
  )
  with self.mesh:
    latents = p_run_inference(
        graphdef=graphdef,
        sharded_state=state,
        rest_of_state=rest_of_state,
        latents=latents,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds
    )

  # BUGFIX: return the result instead of dropping it on the floor.
  return latents
339+
340+
341+
# BUGFIX: the original used static_argnums=(6, 7, 8), but every call site
# passes these arguments by keyword (via functools.partial), which jax.jit
# rejects for positionally-declared statics. static_argnames handles
# keyword-passed statics correctly.
@partial(jax.jit, static_argnames=("guidance_scale", "num_inference_steps", "scheduler"))
def run_inference(
    graphdef,
    sharded_state,
    rest_of_state,
    latents: jnp.ndarray,
    prompt_embeds: jnp.ndarray,
    negative_prompt_embeds: jnp.ndarray,
    guidance_scale: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
    scheduler_state):
  """Jitted denoising loop: run the transformer for num_inference_steps steps.

  The transformer module is rebuilt from its split (graphdef, params, rest)
  triple so the parameters can be sharded outside of jit. Returns the final
  denoised latents.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  # Static under jit, so this Python-level branch is resolved at trace time.
  do_classifier_free_guidance = guidance_scale > 1.0

  # Hoisted out of the loop: convert the schedule once, index per step.
  timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)

  for step in range(num_inference_steps):
    t = timesteps[step]
    timestep = jnp.broadcast_to(t, latents.shape[0])

    noise_pred = wan_transformer(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        return_dict=False
    )[0]

    if do_classifier_free_guidance:
      noise_uncond = wan_transformer(
          hidden_states=latents,
          timestep=timestep,
          encoder_hidden_states=negative_prompt_embeds,
          return_dict=False
      )[0]
      noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

  return latents

0 commit comments

Comments
 (0)