Skip to content

Commit b84fc34

Browse files
implements skip layer guidance for better generations.
1 parent 9ee7fd3 commit b84fc34

4 files changed

Lines changed: 82 additions & 13 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,20 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful
207207
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
208208
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
209209
do_classifier_free_guidance: True
210-
height: 720
211-
width: 1280
210+
height: 480
211+
width: 832
212212
num_frames: 81
213213
guidance_scale: 5.0
214+
flow_shift: 3.0
215+
216+
# skip layer guidance
217+
slg_layers: [9]
218+
slg_start: 0.2
219+
slg_end: 1.0
214220
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
215221
guidance_rescale: 0.0
216222
num_inference_steps: 30
217223
save_final_checkpoint: False
218-
flow_shift: 5.0
219224

220225
# SDXL Lightning parameters
221226
lightning_from_pt: True

src/maxdiffusion/generate_wan.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,15 @@
2121
from maxdiffusion.utils import export_to_video
2222

2323
def run(config):
24+
print("seed: ", config.seed)
2425
pipeline = WanPipeline.from_pretrained(config)
2526
s0 = time.perf_counter()
27+
28+
# Skip layer guidance
29+
slg_layers = config.slg_layers
30+
slg_start = config.slg_start
31+
slg_end = config.slg_end
32+
2633
videos = pipeline(
2734
prompt=config.prompt,
2835
negative_prompt=config.negative_prompt,
@@ -31,6 +38,9 @@ def run(config):
3138
num_frames=config.num_frames,
3239
num_inference_steps=config.num_inference_steps,
3340
guidance_scale=config.guidance_scale,
41+
slg_layers=slg_layers,
42+
slg_start=slg_start,
43+
slg_end=slg_end
3444
)
3545

3646
print("compile time: ", (time.perf_counter() - s0))
@@ -46,6 +56,9 @@ def run(config):
4656
num_frames=config.num_frames,
4757
num_inference_steps=config.num_inference_steps,
4858
guidance_scale=config.guidance_scale,
59+
slg_layers=slg_layers,
60+
slg_start=slg_start,
61+
slg_end=slg_end
4962
)
5063
print("generation time: ", (time.perf_counter() - s0))
5164
for i in range(len(videos)):

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Tuple, Optional, Dict, Union, Any
17+
from typing import Tuple, Optional, Dict, Union, Any, List
1818
import math
1919
import jax
2020
import jax.numpy as jnp
@@ -453,6 +453,8 @@ def __call__(
453453
hidden_states: jax.Array,
454454
timestep: jax.Array,
455455
encoder_hidden_states: jax.Array,
456+
is_uncond: jax.Array, # jnp.bool_ scalar
457+
slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,)
456458
encoder_hidden_states_image: Optional[jax.Array] = None,
457459
return_dict: bool = True,
458460
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -476,8 +478,14 @@ def __call__(
476478

477479
if encoder_hidden_states_image is not None:
478480
raise NotImplementedError("img2vid is not yet implemented.")
479-
for block in self.blocks:
480-
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
481+
for block_idx, block in enumerate(self.blocks):
482+
should_skip_block = slg_mask[block_idx] & is_uncond
483+
hidden_states = jax.lax.cond(
484+
should_skip_block,
485+
lambda hs: hs, # If true, pass through original hidden_states (skip block)
486+
lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb),
487+
hidden_states
488+
)
481489
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
482490

483491
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,10 @@ def __call__(
369369
latents: jax.Array = None,
370370
prompt_embeds: jax.Array = None,
371371
negative_prompt_embeds: jax.Array = None,
372-
vae_only: bool = False
372+
vae_only: bool = False,
373+
slg_layers: List[int] = None,
374+
slg_start: float = 0.0,
375+
slg_end: float = 1.0
373376
):
374377
if not vae_only:
375378
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -424,7 +427,11 @@ def __call__(
424427
guidance_scale=guidance_scale,
425428
num_inference_steps=num_inference_steps,
426429
scheduler=self.scheduler,
427-
scheduler_state=scheduler_state
430+
scheduler_state=scheduler_state,
431+
slg_layers=slg_layers,
432+
slg_start=slg_start,
433+
slg_end=slg_end,
434+
num_transformer_layers=self.transformer.config.num_layers
428435
)
429436

430437
with self.mesh:
@@ -450,12 +457,22 @@ def __call__(
450457

451458

452459
@jax.jit
453-
def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds):
460+
def transformer_forward_pass(
461+
graphdef,
462+
sharded_state,
463+
rest_of_state,
464+
latents,
465+
timestep,
466+
prompt_embeds,
467+
is_uncond,
468+
slg_mask):
454469
wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
455470
return wan_transformer(
456471
hidden_states=latents,
457472
timestep=timestep,
458-
encoder_hidden_states=prompt_embeds
473+
encoder_hidden_states=prompt_embeds,
474+
is_uncond=is_uncond,
475+
slg_mask=slg_mask
459476
)[0]
460477

461478
#@partial(jax.jit, static_argnums=(6, 7, 8))
@@ -469,16 +486,42 @@ def run_inference(
469486
guidance_scale: float,
470487
num_inference_steps: int,
471488
scheduler : FlaxUniPCMultistepScheduler,
472-
scheduler_state):
489+
num_transformer_layers: int,
490+
scheduler_state,
491+
slg_layers: List[int] = None,
492+
slg_start: float = 0.0,
493+
slg_end: float = 1.0
494+
):
473495
do_classifier_free_guidance = guidance_scale > 1.0
474496
for step in range(num_inference_steps):
497+
slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_)
498+
if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps):
499+
slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True)
475500
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
476501
timestep = jnp.broadcast_to(t, latents.shape[0])
477502

478-
noise_pred = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds)
503+
noise_pred = transformer_forward_pass(
504+
graphdef,
505+
sharded_state,
506+
rest_of_state,
507+
latents,
508+
timestep,
509+
prompt_embeds,
510+
is_uncond=jnp.array(False, dtype=jnp.bool_),
511+
slg_mask=slg_mask
512+
)
479513

480514
if do_classifier_free_guidance:
481-
noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds)
515+
noise_uncond = transformer_forward_pass(
516+
graphdef,
517+
sharded_state,
518+
rest_of_state,
519+
latents,
520+
timestep,
521+
negative_prompt_embeds,
522+
is_uncond=jnp.array(True, dtype=jnp.bool_),
523+
slg_mask=slg_mask
524+
)
482525
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
483526
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
484527
return latents

0 commit comments

Comments (0)