
Commit 3cdef23

Adds the VACE logic to WAN
1 parent 1ae2616 commit 3cdef23

1 file changed

Lines changed: 335 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 335 additions & 2 deletions
@@ -13,18 +13,23 @@
  limitations under the License.
 """
 
-from typing import Tuple
+import math
+from typing import Any, Dict, Optional, Tuple
 
 from flax import nnx
+import flax.linen as nn
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec
 
 from .... import common_types
+from ....configuration_utils import register_to_config
 from ...attention_flax import FlaxWanAttention
+from ...gradient_checkpoint import GradientCheckpointType
 from ...normalization_flax import FP32LayerNorm
-from .transformer_wan import WanFeedForward
+from .transformer_wan import WanFeedForward, WanModel, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+
 
 BlockSizes = common_types.BlockSizes
 
@@ -268,3 +273,331 @@ def __call__(
     conditioning_states = self.proj_out(control_hidden_states)
 
     return conditioning_states, control_hidden_states
+
+
+class WanVACEModel(WanModel):
+  """Extension of Wan to include VACE conditioning."""
+
+  @register_to_config
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      vace_layers: list[int],
+      vace_in_channels: int,
+      model_type="t2v",
+      patch_size: Tuple[int, ...] = (1, 2, 2),
+      num_attention_heads: int = 40,
+      attention_head_dim: int = 128,
+      in_channels: int = 16,
+      out_channels: int = 16,
+      text_dim: int = 4096,
+      freq_dim: int = 256,
+      ffn_dim: int = 13824,
+      num_layers: int = 40,
+      dropout: float = 0.0,
+      cross_attn_norm: bool = True,
+      qk_norm: Optional[str] = "rms_norm_across_heads",
+      eps: float = 1e-6,
+      image_dim: Optional[int] = None,
+      added_kv_proj_dim: Optional[int] = None,
+      rope_max_seq_len: int = 1024,
+      pos_embed_seq_len: Optional[int] = None,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
+      remat_policy: str = "None",
+      names_which_can_be_saved: list[str] = [],
+      names_which_can_be_offloaded: list[str] = [],
+      scan_layers: bool = True,
+  ):
"""Initializes the VACE model.
318+
319+
All arguments are similar to WanModel with the exception of:
320+
vace_layers: Indices of the layers at which the VACE conditioning is
321+
injected.
322+
vace_in_channels: Number of channels in the VACE conditioning.
323+
"""
324+
inner_dim = num_attention_heads * attention_head_dim
325+
out_channels = out_channels or in_channels
326+
self.num_layers = num_layers
327+
self.scan_layers = scan_layers
328+
329+
# 1. Patch & position embedding
330+
self.rope = WanRotaryPosEmbed(
331+
attention_head_dim, patch_size, rope_max_seq_len
332+
)
333+
self.patch_embedding = nnx.Conv(
334+
in_channels,
335+
inner_dim,
336+
rngs=rngs,
337+
kernel_size=patch_size,
338+
strides=patch_size,
339+
dtype=dtype,
340+
param_dtype=weights_dtype,
341+
precision=precision,
342+
kernel_init=nnx.with_partitioning(
343+
nnx.initializers.xavier_uniform(),
344+
(None, None, None, None, "conv_out"),
345+
),
346+
)
347+
348+
# 2. Condition embeddings
349+
self.condition_embedder = WanTimeTextImageEmbedding(
350+
rngs=rngs,
351+
dim=inner_dim,
352+
time_freq_dim=freq_dim,
353+
time_proj_dim=inner_dim * 6,
354+
text_embed_dim=text_dim,
355+
image_embed_dim=image_dim,
356+
pos_embed_seq_len=pos_embed_seq_len,
357+
)
358+
359+
self.gradient_checkpoint = GradientCheckpointType.from_str(
360+
remat_policy
361+
)
362+
self.names_which_can_be_offloaded = names_which_can_be_offloaded
363+
self.names_which_can_be_saved = names_which_can_be_saved
364+
365+
# 3. Transformer blocks
366+
367+
if scan_layers:
368+
raise NotImplementedError("scan_layers is not supported yet")
369+
else:
370+
blocks = nnx.List([])
371+
for _ in range(num_layers):
372+
block = WanTransformerBlock(
373+
rngs=rngs,
374+
dim=inner_dim,
375+
ffn_dim=ffn_dim,
376+
num_heads=num_attention_heads,
377+
qk_norm=qk_norm,
378+
cross_attn_norm=cross_attn_norm,
379+
eps=eps,
380+
flash_min_seq_length=flash_min_seq_length,
381+
flash_block_sizes=flash_block_sizes,
382+
mesh=mesh,
383+
dtype=dtype,
384+
weights_dtype=weights_dtype,
385+
precision=precision,
386+
attention=attention,
387+
dropout=dropout,
388+
)
389+
blocks.append(block)
390+
self.blocks = blocks
391+
392+
if scan_layers:
393+
raise NotImplementedError("scan_layers is not supported yet")
394+
else:
395+
vace_blocks = nnx.List([])
396+
397+
for vace_block_id in self.config.vace_layers:
398+
vace_block = WanVACETransformerBlock(
399+
rngs=rngs,
400+
dim=inner_dim,
401+
ffn_dim=ffn_dim,
402+
num_heads=num_attention_heads,
403+
qk_norm=qk_norm,
404+
cross_attn_norm=cross_attn_norm,
405+
eps=eps,
406+
flash_min_seq_length=flash_min_seq_length,
407+
flash_block_sizes=flash_block_sizes,
408+
mesh=mesh,
409+
dtype=dtype,
410+
weights_dtype=weights_dtype,
411+
precision=precision,
412+
attention=attention,
413+
dropout=dropout,
414+
apply_input_projection=vace_block_id == 0,
415+
apply_output_projection=True,
416+
)
417+
vace_blocks.append(vace_block)
418+
self.vace_blocks = vace_blocks
419+
420+
+    self.vace_patch_embedding = nnx.Conv(
+        rngs=rngs,
+        in_features=vace_in_channels,
+        out_features=inner_dim,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (None, None, None, None, "conv_out"),
+        ),
+    )
+
+    self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
+    self.proj_out = nnx.Linear(
+        rngs=rngs,
+        in_features=inner_dim,
+        out_features=out_channels * math.prod(patch_size),
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
+    )
+    key = rngs.params()
+    self.scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
+    )
+
+  @jax.named_scope("WanVACEModel")
+  def __call__(
+      self,
+      hidden_states: jax.Array,
+      timestep: jax.Array,
+      encoder_hidden_states: jax.Array,
+      control_hidden_states: jax.Array,
+      control_hidden_states_scale: Optional[jax.Array] = None,
+      encoder_hidden_states_image: Optional[jax.Array] = None,
+      return_dict: bool = True,
+      attention_kwargs: Optional[Dict[str, Any]] = None,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
+  ) -> jax.Array:
+    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
+    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    p_t, p_h, p_w = self.config.patch_size
+    post_patch_num_frames = num_frames // p_t
+    post_patch_height = height // p_h
+    post_patch_width = width // p_w
+
+    if control_hidden_states_scale is None:
+      control_hidden_states_scale = jnp.ones_like(
+          control_hidden_states, shape=(len(self.config.vace_layers),)
+      )
+    if control_hidden_states_scale.shape[0] != len(self.config.vace_layers):
+      raise ValueError(
+          "Length of `control_hidden_states_scale`"
+          f" {len(control_hidden_states_scale)} should be equal to"
+          f" {len(self.config.vace_layers)}."
+      )
+
+    hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
+    control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
+    rotary_emb = self.rope(hidden_states)
+
+    hidden_states = self.patch_embedding(hidden_states)
+    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+    control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+    control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+    # Pad the control tokens along the sequence axis so they line up with the
+    # patchified hidden states.
+    control_hidden_states_padding = jnp.zeros(
+        (
+            batch_size,
+            hidden_states.shape[1] - control_hidden_states.shape[1],
+            control_hidden_states.shape[2],
+        ),
+        dtype=control_hidden_states.dtype,
+    )
+    control_hidden_states = jnp.concatenate(
+        [control_hidden_states, control_hidden_states_padding], axis=1
+    )
+
+    # Condition embedder is a FC layer.
+    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
+        self.condition_embedder(  # We will need to mask out the text embedding.
+            timestep, encoder_hidden_states, encoder_hidden_states_image
+        )
+    )
+    timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+
+    if encoder_hidden_states_image is not None:
+      raise NotImplementedError("img2vid is not yet implemented.")
+
+    if self.scan_layers:
+      raise NotImplementedError("scan_layers is not supported yet")
+    else:
+      # Prepare VACE hints
+      control_hidden_states_list = nnx.List([])
+      for i, vace_block in enumerate(self.vace_blocks):
+
+        def layer_forward(hidden_states, control_hidden_states):
+          return vace_block(
+              hidden_states=hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              control_hidden_states=control_hidden_states,
+              temb=timestep_proj,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward,
+            self.names_which_can_be_saved,
+            self.names_which_can_be_offloaded,
+            prevent_cse=not self.scan_layers,
+        )
+        conditioning_states, control_hidden_states = rematted_layer_forward(
+            hidden_states, control_hidden_states
+        )
+        control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
+
+      control_hidden_states_list = control_hidden_states_list[::-1]
+
+      for i, block in enumerate(self.blocks):
+
+        def layer_forward_vace(hidden_states):
+          return block(
+              hidden_states,
+              encoder_hidden_states,
+              timestep_proj,
+              rotary_emb,
+              deterministic,
+              rngs,
+          )
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward_vace,
+            self.names_which_can_be_saved,
+            self.names_which_can_be_offloaded,
+            prevent_cse=not self.scan_layers,
+        )
+        hidden_states = rematted_layer_forward(hidden_states)
+        if i in self.config.vace_layers:
+          # Inject the VACE hint prepared for this layer, scaled per layer.
+          control_hint, scale = control_hidden_states_list.pop()
+          hidden_states = hidden_states + control_hint * scale
+
+    # 6. Output norm, projection & unpatchify
+    shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
+
+    hidden_states = (
+        self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift
+    ).astype(hidden_states.dtype)
+    with jax.named_scope("proj_out"):
+      hidden_states = self.proj_out(hidden_states)  # Linear layer.
+
+    # Unpatchify: (b, seq, c * p_t * p_h * p_w) -> (b, c, f, h, w).
+    hidden_states = hidden_states.reshape(
+        batch_size,
+        post_patch_num_frames,
+        post_patch_height,
+        post_patch_width,
+        p_t,
+        p_h,
+        p_w,
+        -1,
+    )
+    hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
+    hidden_states = jax.lax.collapse(hidden_states, 6, None)
+    hidden_states = jax.lax.collapse(hidden_states, 4, 6)
+    hidden_states = jax.lax.collapse(hidden_states, 2, 4)
+    return hidden_states
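
Two illustrative snippets follow; neither is part of the commit.

First, a standalone shape check of the unpatchify tail of __call__ (the reshape / transpose / jax.lax.collapse sequence above), using toy sizes and the default patch_size of (1, 2, 2):

import jax
import jax.numpy as jnp

b, c, p_t, p_h, p_w = 1, 16, 1, 2, 2                   # patch_size = (1, 2, 2)
f, h, w = 4, 4, 4                                      # post-patch frames / height / width
x = jnp.zeros((b, f * h * w, c * p_t * p_h * p_w))     # proj_out output: one token per patch
x = x.reshape(b, f, h, w, p_t, p_h, p_w, -1)
x = jnp.transpose(x, (0, 7, 1, 4, 2, 5, 3, 6))         # (b, c, f, p_t, h, p_h, w, p_w)
x = jax.lax.collapse(x, 6, None)                       # fold p_w into w
x = jax.lax.collapse(x, 4, 6)                          # fold p_h into h
x = jax.lax.collapse(x, 2, 4)                          # fold p_t into f
print(x.shape)                                         # (1, 16, 4, 8, 8)

Second, a minimal, hedged usage sketch of the new class. It assumes the surrounding maxdiffusion modules behave as in the base WanModel; the import path follows the file location above, and the tiny dimensions, vace_layers=[0], and vace_in_channels=96 are toy assumptions chosen only to exercise the code path, not values from any released checkpoint:

import jax.numpy as jnp
from flax import nnx

from maxdiffusion.models.wan.transformers.transformer_wan_vace import WanVACEModel

model = WanVACEModel(
    rngs=nnx.Rngs(0),
    vace_layers=[0],              # inject the VACE hint into transformer block 0
    vace_in_channels=96,          # channel count of the VACE conditioning latents (assumed)
    num_attention_heads=2,
    attention_head_dim=8,
    text_dim=32,
    freq_dim=32,
    ffn_dim=64,
    num_layers=2,
    scan_layers=False,            # scanned layers raise NotImplementedError in this commit
)

latents = jnp.zeros((1, 16, 4, 8, 8))        # (batch, channels, frames, height, width)
control = jnp.zeros((1, 96, 4, 8, 8))        # VACE conditioning, same temporal/spatial grid
prompt_emb = jnp.zeros((1, 77, 32))          # pre-encoded text embeddings
timestep = jnp.array([500])

out = model(
    hidden_states=latents,
    timestep=timestep,
    encoder_hidden_states=prompt_emb,
    control_hidden_states=control,
    control_hidden_states_scale=jnp.ones((1,)),   # one scale per entry in vace_layers
)
print(out.shape)                              # (1, 16, 4, 8, 8): same layout as the input latents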
