 limitations under the License.
 """
 
+import contextlib
 import math
 from typing import Any, Dict, Optional, Tuple
 
 import jax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec
 
 from .... import common_types
 from ....configuration_utils import register_to_config
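
Note: the import change above pairs with the sharding refactor further down, where hard-coded PartitionSpec("data", "fsdp", ...) constraints are replaced by Flax logical axis names resolved through nn.logical_to_mesh_axes. A minimal sketch of that pattern, with an illustrative single-device mesh and axis rules (assumptions for the sketch; the real rules live in the model config, not here):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import numpy as np
    from jax.sharding import Mesh, NamedSharding

    # Illustrative logical->physical mapping; an assumption, not this repo's config.
    rules = (
        ("activation_batch", "data"),
        ("activation_length", "fsdp"),
        ("activation_heads", "tensor"),
    )
    devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
    mesh = Mesh(devices, ("data", "fsdp", "tensor"))

    x = jnp.zeros((2, 128, 64))
    with mesh, nn.logical_axis_rules(rules):
      # Resolves the logical names to PartitionSpec("data", "fsdp", "tensor").
      spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
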
@@ -62,8 +62,12 @@ def __init__(
       precision: jax.lax.Precision | None = None,
       attention: str = "dot_product",
       dropout: float = 0.0,
+      mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
       apply_input_projection: bool = False,
       apply_output_projection: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Sets up the model.
 
@@ -90,7 +94,7 @@ def __init__(
       apply_output_projection: Whether to apply an output projection before
         outputting the result.
     """
-
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.apply_input_projection = apply_input_projection
     self.apply_output_projection = apply_output_projection
 
@@ -124,7 +128,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
 
     # 3. Cross-attention
@@ -143,7 +152,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
     assert cross_attn_norm is True, "cross_attn_norm must be True"
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +172,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +195,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       *,
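
Note: the conditional_named_scope helper added above gates profiler annotations behind a flag: jax.named_scope labels the enclosed operations in profiler traces, while contextlib.nullcontext makes the disabled path a no-op so nothing is paid when tracing is off. A self-contained usage sketch of the same pattern:

    import contextlib
    import jax
    import jax.numpy as jnp

    def conditional_named_scope(name: str, enabled: bool):
      # Label ops for the profiler only when enabled.
      return jax.named_scope(name) if enabled else contextlib.nullcontext()

    @jax.jit
    def block(x):
      with conditional_named_scope("self_attn", enabled=True):
        x = x @ x.T
      return x

    block(jnp.ones((4, 4)))
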
@@ -191,65 +210,74 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs | None = None,
   ) -> Tuple[jax.Array, jax.Array]:
-    if self.apply_input_projection:
-      control_hidden_states = self.proj_in(control_hidden_states)
-      control_hidden_states = control_hidden_states + hidden_states
-
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-
-    control_hidden_states = jax.lax.with_sharding_constraint(
-        control_hidden_states,
-        PartitionSpec("data", "fsdp", "tensor"),
-    )
-    control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(
-        encoder_hidden_states,
-        PartitionSpec("data", "fsdp", None),
-    )
-
-    # 1. Self-attention
-    with jax.named_scope("attn1"):
-      norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-          control_hidden_states.dtype
+    with self.conditional_named_scope("vace_transformer_block"):
+      with self.conditional_named_scope("input_projection"):
+        if self.apply_input_projection:
+          control_hidden_states = self.proj_in(control_hidden_states)
+          control_hidden_states = control_hidden_states + hidden_states
+
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
       )
-      attn_output = self.attn1(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=norm_hidden_states,
-          rotary_emb=rotary_emb,
-          deterministic=deterministic,
-          rngs=rngs,
-      )
-      control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
-          control_hidden_states.dtype
-      )
-
-    # 2. Cross-attention
-    with jax.named_scope("attn2"):
-      norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
-      attn_output = self.attn2(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=encoder_hidden_states,
-          deterministic=deterministic,
-          rngs=rngs,
-      )
-      control_hidden_states = control_hidden_states + attn_output
 
-    # 3. Feed-forward
-    with jax.named_scope("ffn"):
-      norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-          control_hidden_states.dtype
-      )
-      ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-      control_hidden_states = (
-          control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
-      ).astype(control_hidden_states.dtype)
-    conditioning_states = None
-    if self.apply_output_projection:
-      conditioning_states = self.proj_out(control_hidden_states)
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+      control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
+      control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              control_hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
+              control_hidden_states.dtype
+          )
 
-    return conditioning_states, control_hidden_states
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          control_hidden_states = control_hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (
+              self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
+          ).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          control_hidden_states = (
+              control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
+          ).astype(control_hidden_states.dtype)
+
+      with self.conditional_named_scope("output_projection"):
+        conditioning_states = None
+        if self.apply_output_projection:
+          conditioning_states = self.proj_out(control_hidden_states)
+
+    return conditioning_states, control_hidden_states
 
 
 class WanVACEModel(WanModel):
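
Note: for reference, the block's modulation math (unchanged by this diff apart from the scoping) follows the usual AdaLN recipe visible in the hunk above: the learned table plus the time embedding is split into six (shift, scale, gate) chunks, each sub-layer sees norm(x) * (1 + scale) + shift, and its output re-enters the residual stream through the gate. A standalone numeric sketch with stand-in tensors:

    import jax.numpy as jnp

    dim, seq = 16, 8
    x = jnp.ones((1, seq, dim))
    table = jnp.zeros((1, 6, dim))  # stand-in for adaln_scale_shift_table
    temb = jnp.ones((1, 6, dim))    # stand-in for the projected time embedding

    shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = jnp.split(table + temb, 6, axis=1)

    def layer_norm(v):
      mu = v.mean(-1, keepdims=True)
      var = v.var(-1, keepdims=True)
      return (v - mu) / jnp.sqrt(var + 1e-6)

    h = layer_norm(x) * (1 + scale_msa) + shift_msa  # modulated self-attention input
    attn_out = h                                     # attention itself elided
    x = x + attn_out * gate_msa                      # gated residual, as in the block
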
@@ -289,7 +317,11 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list[str] = [],
       names_which_can_be_offloaded: list[str] = [],
+      mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Initializes the VACE model.
 
@@ -302,6 +334,7 @@ def __init__(
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +362,7 @@ def __init__(
         text_embed_dim=text_dim,
         image_embed_dim=image_dim,
         pos_embed_seq_len=pos_embed_seq_len,
+        flash_min_seq_length=flash_min_seq_length,
     )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
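
Note: GradientCheckpointType.from_str(remat_policy), together with names_which_can_be_saved and names_which_can_be_offloaded, plugs into JAX's rematerialization machinery: activations tagged with checkpoint_name (like "control_hidden_states" in the VACE block above) can be matched by a save policy, and everything else is recomputed on the backward pass. A minimal sketch of the underlying JAX mechanism, independent of this repo's wrapper:

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def layer(w, x):
      h = jnp.tanh(x @ w)
      h = checkpoint_name(h, "self_attn")  # tag this activation by name
      return jnp.sum(h * h)

    # Save only activations tagged "self_attn"; recompute the rest when differentiating.
    policy = jax.checkpoint_policies.save_only_these_names("self_attn")
    grad_fn = jax.grad(jax.checkpoint(layer, policy=policy))

    w = jnp.ones((4, 4))
    x = jnp.ones((2, 4))
    grad_fn(w, x)
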
@@ -358,6 +392,10 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       blocks.append(block)
     self.blocks = blocks
@@ -384,8 +422,12 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
           apply_input_projection=vace_block_id == 0,
           apply_output_projection=True,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       vace_blocks.append(vace_block)
     self.vace_blocks = vace_blocks
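
Note: use_base2_exp is only threaded through to the attention blocks here; its implementation is not part of this diff. A common motivation for such a flag, stated as an assumption rather than as what this kernel necessarily does, is evaluating softmax via exp2, since exp(x) = 2^(x * log2 e) and exp2 can be cheaper on some accelerators:

    import jax.numpy as jnp

    LOG2_E = 1.4426950408889634  # log2(e)

    def softmax_base2(logits):
      # Numerically stable softmax; exp2(x * log2(e)) equals exp(x).
      shifted = logits - logits.max(axis=-1, keepdims=True)
      unnorm = jnp.exp2(shifted * LOG2_E)
      return unnorm / unnorm.sum(axis=-1, keepdims=True)
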
@@ -421,6 +463,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   @jax.named_scope("WanVACEModel")
   def __call__(
       self,
@@ -436,7 +482,7 @@ def __call__(
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
-    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
     post_patch_height = height // p_h
@@ -453,13 +499,14 @@ def __call__(
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
     control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-
-    hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    control_hidden_states = self.vace_patch_embedding(control_hidden_states)
-    control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
+      hidden_states = self.patch_embedding(hidden_states)
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+      control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+      control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
     control_hidden_states_padding = jnp.zeros((
         batch_size,
         control_hidden_states.shape[1],
@@ -469,16 +516,17 @@ def __call__(
     control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
 
     # Condition embedder is a FC layer.
-    (
-        temb,
-        timestep_proj,
-        encoder_hidden_states,
-        encoder_hidden_states_image,
-        _,
-    ) = self.condition_embedder(  # We will need to mask out the text embedding.
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
-    timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      (
+          temb,
+          timestep_proj,
+          encoder_hidden_states,
+          encoder_hidden_states_image,
+          _,
+      ) = self.condition_embedder(  # We will need to mask out the text embedding.
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
+      timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
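
Note: in the patch-embedding path above, the embedding layer reduces each (p_t, p_h, p_w) patch to one vector, and jax.lax.collapse(x, 1, -1) then merges the (frames, height, width) grid into a single token axis. A shape-only sketch with illustrative sizes (the convolution itself is elided; this mirrors the collapse call in the diff):

    import jax
    import jax.numpy as jnp

    B, Fp, Hp, Wp, dim = 1, 8, 32, 32, 16      # post-patch grid, channels-last
    patched = jnp.zeros((B, Fp, Hp, Wp, dim))  # stand-in for patch_embedding output
    tokens = jax.lax.collapse(patched, 1, -1)  # merge dims 1..3 into one sequence axis
    assert tokens.shape == (B, Fp * Hp * Wp, dim)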