 limitations under the License.
 """
 
+import contextlib
 import math
 from typing import Any, Dict, Optional, Tuple
 
@@ -62,8 +63,12 @@ def __init__(
       precision: jax.lax.Precision | None = None,
       attention: str = "dot_product",
       dropout: float = 0.0,
+      mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
       apply_input_projection: bool = False,
       apply_output_projection: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Sets up the model.
 
@@ -90,7 +95,7 @@ def __init__(
       apply_output_projection: Whether to apply an output projection before
         outputting the result.
     """
-
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.apply_input_projection = apply_input_projection
     self.apply_output_projection = apply_output_projection
 
@@ -124,7 +129,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
 
     # 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
     assert cross_attn_norm is True, "cross_attn_norm must be True"
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +173,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +196,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       *,
@@ -191,65 +211,74 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs | None = None,
   ) -> Tuple[jax.Array, jax.Array]:
-    if self.apply_input_projection:
-      control_hidden_states = self.proj_in(control_hidden_states)
-      control_hidden_states = control_hidden_states + hidden_states
-
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-
-    control_hidden_states = jax.lax.with_sharding_constraint(
-        control_hidden_states,
-        PartitionSpec("data", "fsdp", "tensor"),
-    )
-    control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(
-        encoder_hidden_states,
-        PartitionSpec("data", "fsdp", None),
-    )
-
-    # 1. Self-attention
-    with jax.named_scope("attn1"):
-      norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-          control_hidden_states.dtype
+    with self.conditional_named_scope("vace_transformer_block"):
+      with self.conditional_named_scope("input_projection"):
+        if self.apply_input_projection:
+          control_hidden_states = self.proj_in(control_hidden_states)
+          control_hidden_states = control_hidden_states + hidden_states
+
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
       )
-      attn_output = self.attn1(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=norm_hidden_states,
-          rotary_emb=rotary_emb,
-          deterministic=deterministic,
-          rngs=rngs,
-      )
-      control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
-          control_hidden_states.dtype
-      )
-
-    # 2. Cross-attention
-    with jax.named_scope("attn2"):
-      norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
-      attn_output = self.attn2(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=encoder_hidden_states,
-          deterministic=deterministic,
-          rngs=rngs,
-      )
-      control_hidden_states = control_hidden_states + attn_output
 
-    # 3. Feed-forward
-    with jax.named_scope("ffn"):
-      norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-          control_hidden_states.dtype
-      )
-      ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-      control_hidden_states = (
-          control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
-      ).astype(control_hidden_states.dtype)
-    conditioning_states = None
-    if self.apply_output_projection:
-      conditioning_states = self.proj_out(control_hidden_states)
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+      control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
+      control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              control_hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
+              control_hidden_states.dtype
+          )
 
-    return conditioning_states, control_hidden_states
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          control_hidden_states = control_hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (
+              self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
+          ).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          control_hidden_states = (
+              control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
+          ).astype(control_hidden_states.dtype)
+
+      with self.conditional_named_scope("output_projection"):
+        conditioning_states = None
+        if self.apply_output_projection:
+          conditioning_states = self.proj_out(control_hidden_states)
+
+    return conditioning_states, control_hidden_states
 
 
 class WanVACEModel(WanModel):
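The hunk above also replaces the hard-coded `PartitionSpec("data", "fsdp", ...)` constraints with logical axis names resolved through `nn.logical_to_mesh_axes`. A small sketch of how logical names map to mesh axes under a rule table; the rules shown here are assumptions for illustration, not this project's sharding config:

```python
import flax.linen as nn

# Hypothetical logical-to-physical axis rules; real projects keep these in config.
rules = (
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_heads", "tensor"),
)

# Resolve a tuple of logical axis names into a PartitionSpec over mesh axes.
spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"), rules=rules)
print(spec)  # expected: PartitionSpec('data', 'fsdp', 'tensor')
```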
@@ -289,7 +318,11 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list[str] = [],
       names_which_can_be_offloaded: list[str] = [],
+      mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Initializes the VACE model.
 
@@ -302,6 +335,7 @@ def __init__(
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +363,7 @@ def __init__(
         text_embed_dim=text_dim,
         image_embed_dim=image_dim,
         pos_embed_seq_len=pos_embed_seq_len,
+        flash_min_seq_length=flash_min_seq_length,
     )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +393,10 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       blocks.append(block)
     self.blocks = blocks
@@ -384,8 +423,12 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
           apply_input_projection=vace_block_id == 0,
           apply_output_projection=True,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       vace_blocks.append(vace_block)
     self.vace_blocks = vace_blocks
@@ -421,6 +464,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   @jax.named_scope("WanVACEModel")
   def __call__(
       self,
@@ -436,7 +483,7 @@ def __call__(
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
-    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
     post_patch_height = height // p_h
@@ -453,13 +500,14 @@ def __call__(
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
     control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-
-    hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    control_hidden_states = self.vace_patch_embedding(control_hidden_states)
-    control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
+      hidden_states = self.patch_embedding(hidden_states)
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+      control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+      control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
     control_hidden_states_padding = jnp.zeros((
         batch_size,
         control_hidden_states.shape[1],
@@ -469,16 +517,17 @@ def __call__(
     control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
 
     # Condition embedder is a FC layer.
-    (
-        temb,
-        timestep_proj,
-        encoder_hidden_states,
-        encoder_hidden_states_image,
-        _,
-    ) = self.condition_embedder(  # We will need to mask out the text embedding.
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
-    timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      (
+          temb,
+          timestep_proj,
+          encoder_hidden_states,
+          encoder_hidden_states_image,
+          _,
+      ) = self.condition_embedder(  # We will need to mask out the text embedding.
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
+      timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
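`use_base2_exp` is threaded into the attention blocks; presumably it switches the softmax inside the attention kernel to base-2 exponentials, which are often cheaper on accelerators. The identity it relies on is exp(v) = 2**(v * log2(e)), so scaling the logits by log2(e) and using exp2 reproduces the standard softmax. A hedged numerical sketch of that equivalence, not the kernel's actual code:

```python
import jax.numpy as jnp

x = jnp.array([0.5, -1.0, 2.0, 0.0])

# Reference softmax with the natural exponential.
ref = jnp.exp(x - x.max())
ref = ref / ref.sum()

# Base-2 variant: exp(v) == 2 ** (v * log2(e)), so pre-scale by log2(e) and use exp2.
LOG2_E = 1.4426950408889634
alt = jnp.exp2((x - x.max()) * LOG2_E)
alt = alt / alt.sum()

print(jnp.allclose(ref, alt, atol=1e-6))  # True up to float rounding
```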