Commit 46d51a8

Adds the VACE logic to WAN
1 parent 1ae2616 commit 46d51a8

1 file changed: src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 336 additions & 1 deletion
@@ -15,16 +15,23 @@
 
 from typing import Tuple
 
+import math
+from typing import Any, Dict, Optional, Tuple
+
 from flax import nnx
+import flax.linen as nn
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec
 
 from .... import common_types
+from ....configuration_utils import register_to_config
 from ...attention_flax import FlaxWanAttention
+from ...gradient_checkpoint import GradientCheckpointType
 from ...normalization_flax import FP32LayerNorm
-from .transformer_wan import WanFeedForward
+from .transformer_wan import WanFeedForward, WanModel, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+
 
 BlockSizes = common_types.BlockSizes
 
@@ -268,3 +275,331 @@ def __call__(
     conditioning_states = self.proj_out(control_hidden_states)
 
     return conditioning_states, control_hidden_states
+
+
+class WanVACEModel(WanModel):
+  """Extension of Wan to include VACE conditioning."""
+
+  @register_to_config
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      vace_layers: list[int],
+      vace_in_channels: int,
+      model_type="t2v",
+      patch_size: Tuple[int, ...] = (1, 2, 2),
+      num_attention_heads: int = 40,
+      attention_head_dim: int = 128,
+      in_channels: int = 16,
+      out_channels: int = 16,
+      text_dim: int = 4096,
+      freq_dim: int = 256,
+      ffn_dim: int = 13824,
+      num_layers: int = 40,
+      dropout: float = 0.0,
+      cross_attn_norm: bool = True,
+      qk_norm: Optional[str] = "rms_norm_across_heads",
+      eps: float = 1e-6,
+      image_dim: Optional[int] = None,
+      added_kv_proj_dim: Optional[int] = None,
+      rope_max_seq_len: int = 1024,
+      pos_embed_seq_len: Optional[int] = None,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
+      remat_policy: str = "None",
+      names_which_can_be_saved: list[str] = [],
+      names_which_can_be_offloaded: list[str] = [],
+      scan_layers: bool = True,
+  ):
+    """Initializes the VACE model.
+
+    All arguments are similar to WanModel with the exception of:
+      vace_layers: Indices of the layers at which the VACE conditioning is
+        injected.
+      vace_in_channels: Number of channels in the VACE conditioning.
+    """
+    inner_dim = num_attention_heads * attention_head_dim
+    out_channels = out_channels or in_channels
+    self.num_layers = num_layers
+    self.scan_layers = scan_layers
+
+    # 1. Patch & position embedding
+    self.rope = WanRotaryPosEmbed(
+        attention_head_dim, patch_size, rope_max_seq_len
+    )
+    self.patch_embedding = nnx.Conv(
+        in_channels,
+        inner_dim,
+        rngs=rngs,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
+    )
+
+    # 2. Condition embeddings
+    self.condition_embedder = WanTimeTextImageEmbedding(
+        rngs=rngs,
+        dim=inner_dim,
+        time_freq_dim=freq_dim,
+        time_proj_dim=inner_dim * 6,
+        text_embed_dim=text_dim,
+        image_embed_dim=image_dim,
+        pos_embed_seq_len=pos_embed_seq_len,
+    )
+
+    self.gradient_checkpoint = GradientCheckpointType.from_str(
+        remat_policy
+    )
+    self.names_which_can_be_offloaded = names_which_can_be_offloaded
+    self.names_which_can_be_saved = names_which_can_be_saved
+
+    # 3. Transformer blocks
+
+    if scan_layers:
+      raise NotImplementedError("scan_layers is not supported yet")
+    else:
+      blocks = nnx.List([])
+      for _ in range(num_layers):
+        block = WanTransformerBlock(
+            rngs=rngs,
+            dim=inner_dim,
+            ffn_dim=ffn_dim,
+            num_heads=num_attention_heads,
+            qk_norm=qk_norm,
+            cross_attn_norm=cross_attn_norm,
+            eps=eps,
+            flash_min_seq_length=flash_min_seq_length,
+            flash_block_sizes=flash_block_sizes,
+            mesh=mesh,
+            dtype=dtype,
+            weights_dtype=weights_dtype,
+            precision=precision,
+            attention=attention,
+            dropout=dropout,
+        )
+        blocks.append(block)
+      self.blocks = blocks
+
+    if scan_layers:
+      raise NotImplementedError("scan_layers is not supported yet")
+    else:
+      vace_blocks = nnx.List([])
+
+      for vace_block_id in self.config.vace_layers:
+        vace_block = WanVACETransformerBlock(
+            rngs=rngs,
+            dim=inner_dim,
+            ffn_dim=ffn_dim,
+            num_heads=num_attention_heads,
+            qk_norm=qk_norm,
+            cross_attn_norm=cross_attn_norm,
+            eps=eps,
+            flash_min_seq_length=flash_min_seq_length,
+            flash_block_sizes=flash_block_sizes,
+            mesh=mesh,
+            dtype=dtype,
+            weights_dtype=weights_dtype,
+            precision=precision,
+            attention=attention,
+            dropout=dropout,
+            apply_input_projection=vace_block_id == 0,
+            apply_output_projection=True,
+        )
+        vace_blocks.append(vace_block)
+      self.vace_blocks = vace_blocks
+
+    self.vace_patch_embedding = nnx.Conv(
+        rngs=rngs,
+        in_features=vace_in_channels,
+        out_features=inner_dim,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
+    )
+
+    self.norm_out = FP32LayerNorm(
+        rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False
+    )
+    self.proj_out = nnx.Linear(
+        rngs=rngs,
+        in_features=inner_dim,
+        out_features=out_channels * math.prod(patch_size),
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(), ("embed", None)
+        ),
+    )
+    key = rngs.params()
+    self.scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(), (None, None, "embed")
+        ),
+    )
+
+  @jax.named_scope("WanVACEModel")
+  def __call__(
+      self,
+      hidden_states: jax.Array,
+      timestep: jax.Array,
+      encoder_hidden_states: jax.Array,
+      control_hidden_states: jax.Array,
+      control_hidden_states_scale: Optional[jax.Array] = None,
+      encoder_hidden_states_image: Optional[jax.Array] = None,
+      return_dict: bool = True,
+      attention_kwargs: Optional[Dict[str, Any]] = None,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
+  ) -> jax.Array:
+    hidden_states = nn.with_logical_constraint(
+        hidden_states, ("batch", None, None, None, None)
+    )
+    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    p_t, p_h, p_w = self.config.patch_size
+    post_patch_num_frames = num_frames // p_t
+    post_patch_height = height // p_h
+    post_patch_width = width // p_w
+
+    if control_hidden_states_scale is None:
+      control_hidden_states_scale = jnp.ones_like(
+          control_hidden_states, shape=(len(self.config.vace_layers),)
+      )
+    if control_hidden_states_scale.shape[0] != len(self.config.vace_layers):
+      raise ValueError(
+          "Length of `control_hidden_states_scale`"
+          f" {len(control_hidden_states_scale)} should be equal to"
+          f" {len(self.config.vace_layers)}."
+      )
+
+    hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
+    control_hidden_states = jnp.transpose(
+        control_hidden_states, (0, 2, 3, 4, 1)
+    )
+    rotary_emb = self.rope(hidden_states)
+
+    hidden_states = self.patch_embedding(hidden_states)
+    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+    control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+    control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+    control_hidden_states_padding = jnp.zeros((
+        batch_size,
+        control_hidden_states.shape[1],
+        hidden_states.shape[2] - control_hidden_states.shape[2],
+    ))
+
+    control_hidden_states = jnp.concatenate(
+        [control_hidden_states, control_hidden_states_padding], axis=2
+    )
+
+    # Condition embedder is a FC layer.
+    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
+        self.condition_embedder(  # We will need to mask out the text embedding.
+            timestep, encoder_hidden_states, encoder_hidden_states_image
+        )
+    )
+    timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+
+    if encoder_hidden_states_image is not None:
+      raise NotImplementedError("img2vid is not yet implemented.")
+
+    if self.scan_layers:
+      raise NotImplementedError("scan_layers is not supported yet")
+    else:
+      # Prepare VACE hints
+      control_hidden_states_list = nnx.List([])
+      for i, vace_block in enumerate(self.vace_blocks):
+        def layer_forward(hidden_states, control_hidden_states):
+          return vace_block(
+              hidden_states=hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              control_hidden_states=control_hidden_states,
+              temb=timestep_proj,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward,
+            self.names_which_can_be_saved,
+            self.names_which_can_be_offloaded,
+            prevent_cse=not self.scan_layers,
+        )
+        conditioning_states, control_hidden_states = rematted_layer_forward(
+            hidden_states, control_hidden_states
+        )
+        control_hidden_states_list.append(
+            (conditioning_states, control_hidden_states_scale[i])
+        )
+
+      control_hidden_states_list = control_hidden_states_list[::-1]
+
+      for i, block in enumerate(self.blocks):
+
+        def layer_forward_vace(hidden_states):
+          return block(
+              hidden_states,
+              encoder_hidden_states,
+              timestep_proj,
+              rotary_emb,
+              deterministic,
+              rngs,
+          )
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward_vace,
+            self.names_which_can_be_saved,
+            self.names_which_can_be_offloaded,
+            prevent_cse=not self.scan_layers,
+        )
+        hidden_states = rematted_layer_forward(hidden_states)
+        if i in self.config.vace_layers:
+          control_hint, scale = control_hidden_states_list.pop()
+          hidden_states = hidden_states + control_hint * scale
+
+    # 6. Output norm, projection & unpatchify
+    shift, scale = jnp.split(
+        self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1
+    )
+
+    hidden_states = (
+        self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift
+    ).astype(hidden_states.dtype)
+    with jax.named_scope("proj_out"):
+      hidden_states = self.proj_out(hidden_states)  # Linear layer.
+
+    hidden_states = hidden_states.reshape(
+        batch_size,
+        post_patch_num_frames,
+        post_patch_height,
+        post_patch_width,
+        p_t,
+        p_h,
+        p_w,
+        -1,
+    )
+    hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
+    hidden_states = jax.lax.collapse(hidden_states, 6, None)
+    hidden_states = jax.lax.collapse(hidden_states, 4, 6)
+    hidden_states = jax.lax.collapse(hidden_states, 2, 4)
+    return hidden_states
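
Note on the unpatchify step at the end of WanVACEModel.__call__: the sketch below replays the same reshape / transpose / jax.lax.collapse sequence on a dummy array so the index gymnastics are easier to follow. The batch size, channel count, patch size, and post-patch grid are illustrative values chosen for the example, not taken from the commit.

import jax
import jax.numpy as jnp

# Illustrative sizes only: 16 output channels, patch_size (1, 2, 2),
# and a 4 x 6 x 6 post-patch (frames, height, width) grid.
batch_size, out_channels = 1, 16
p_t, p_h, p_w = 1, 2, 2
post_patch_num_frames, post_patch_height, post_patch_width = 4, 6, 6

num_tokens = post_patch_num_frames * post_patch_height * post_patch_width
hidden_states = jnp.ones((batch_size, num_tokens, out_channels * p_t * p_h * p_w))

# Same unpatchify sequence as the end of WanVACEModel.__call__:
hidden_states = hidden_states.reshape(
    batch_size,
    post_patch_num_frames,
    post_patch_height,
    post_patch_width,
    p_t,
    p_h,
    p_w,
    -1,
)
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
hidden_states = jax.lax.collapse(hidden_states, 6, None)  # merge width and p_w
hidden_states = jax.lax.collapse(hidden_states, 4, 6)     # merge height and p_h
hidden_states = jax.lax.collapse(hidden_states, 2, 4)     # merge frames and p_t
print(hidden_states.shape)  # (1, 16, 4, 12, 12)

The result has the layout (batch, channels, frames, height, width), matching the latent layout the model receives before its initial transpose to channels-last.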
