@@ -13,6 +13,7 @@
 limitations under the License.
 """
 
+import contextlib
 import math
 from typing import Any, Dict, Optional, Tuple
 
@@ -62,8 +63,12 @@ def __init__(
         precision: jax.lax.Precision | None = None,
         attention: str = "dot_product",
         dropout: float = 0.0,
+        mask_padding_tokens: bool = True,
+        enable_jax_named_scopes: bool = False,
         apply_input_projection: bool = False,
         apply_output_projection: bool = False,
+        use_base2_exp: bool = False,
+        use_experimental_scheduler: bool = False,
     ):
         """Sets up the model.
 
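The new `use_base2_exp` flag is only plumbed through to the attention kernel here; the kernel itself is defined elsewhere in the codebase. A common reading of such a flag is that the attention softmax is computed with `exp2` instead of `exp`, which is cheaper on some accelerators. A minimal sketch of that identity, assuming this is what the kernel does (the names below are illustrative, not from this PR):

import jax.numpy as jnp

LOG2_E = 1.4426950408889634  # log2(e), so exp(x) == 2 ** (x * log2(e))

def softmax_base2(x, axis=-1):
    x = x - jnp.max(x, axis=axis, keepdims=True)  # standard max-subtraction for stability
    num = jnp.exp2(x * LOG2_E)                    # exp2 in place of exp
    return num / jnp.sum(num, axis=axis, keepdims=True)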
@@ -90,7 +95,7 @@ def __init__(
           apply_output_projection: Whether to apply an output projection before
             outputting the result.
         """
-
+        self.enable_jax_named_scopes = enable_jax_named_scopes
         self.apply_input_projection = apply_input_projection
         self.apply_output_projection = apply_output_projection
 
@@ -124,7 +129,12 @@ def __init__(
             precision=precision,
             attention_kernel=attention,
             dropout=dropout,
+            is_self_attention=True,
+            mask_padding_tokens=mask_padding_tokens,
             residual_checkpoint_name="self_attn",
+            enable_jax_named_scopes=enable_jax_named_scopes,
+            use_base2_exp=use_base2_exp,
+            use_experimental_scheduler=use_experimental_scheduler,
         )
 
         # 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
             precision=precision,
             attention_kernel=attention,
             dropout=dropout,
+            is_self_attention=False,
+            mask_padding_tokens=mask_padding_tokens,
             residual_checkpoint_name="cross_attn",
+            enable_jax_named_scopes=enable_jax_named_scopes,
+            use_base2_exp=use_base2_exp,
+            use_experimental_scheduler=use_experimental_scheduler,
         )
         assert cross_attn_norm is True, "cross_attn_norm must be True"
         self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +173,7 @@ def __init__(
             weights_dtype=weights_dtype,
             precision=precision,
             dropout=dropout,
+            enable_jax_named_scopes=enable_jax_named_scopes,
         )
 
         self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +196,10 @@ def __init__(
             jax.random.normal(key, (1, 6, dim)) / dim**0.5,
         )
 
+    def conditional_named_scope(self, name: str):
+        """Return a JAX named scope if enabled, otherwise a null context."""
+        return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
     def __call__(
         self,
         *,
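The `conditional_named_scope` helper added above gates profiler annotations behind a flag: `jax.named_scope` pushes a name onto the JAX name stack (visible in HLO metadata and profiler traces), while `contextlib.nullcontext` makes the `with` block a no-op, so the computation is identical either way. A standalone sketch of the pattern:

import contextlib
import jax
import jax.numpy as jnp

def conditional_named_scope(name, enabled):
    # Tag ops for traces when enabled; otherwise a do-nothing context manager.
    return jax.named_scope(name) if enabled else contextlib.nullcontext()

def f(x, enabled=True):
    with conditional_named_scope("self_attn", enabled):
        return x * 2.0  # same computation with or without the scope

jax.jit(f)(jnp.ones((2, 2)))  # ops are tagged "self_attn" when enabled=True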
@@ -191,65 +211,75 @@ def __call__(
         deterministic: bool = True,
         rngs: nnx.Rngs | None = None,
     ) -> Tuple[jax.Array, jax.Array]:
-        if self.apply_input_projection:
-            control_hidden_states = self.proj_in(control_hidden_states)
-            control_hidden_states = control_hidden_states + hidden_states
-
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-        )
 
-        control_hidden_states = jax.lax.with_sharding_constraint(
-            control_hidden_states,
-            PartitionSpec("data", "fsdp", "tensor"),
-        )
-        control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
-        encoder_hidden_states = jax.lax.with_sharding_constraint(
-            encoder_hidden_states,
-            PartitionSpec("data", "fsdp", None),
-        )
+        with self.conditional_named_scope("vace_transformer_block"):
+            with self.conditional_named_scope("input_projection"):
+                if self.apply_input_projection:
+                    control_hidden_states = self.proj_in(control_hidden_states)
+                    control_hidden_states = control_hidden_states + hidden_states
 
-        # 1. Self-attention
-        with jax.named_scope("attn1"):
-            norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-                control_hidden_states.dtype
-            )
-            attn_output = self.attn1(
-                hidden_states=norm_hidden_states,
-                encoder_hidden_states=norm_hidden_states,
-                rotary_emb=rotary_emb,
-                deterministic=deterministic,
-                rngs=rngs,
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+                (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
             )
-            control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
-                control_hidden_states.dtype
-            )
-
-        # 2. Cross-attention
-        with jax.named_scope("attn2"):
-            norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
-            attn_output = self.attn2(
-                hidden_states=norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                deterministic=deterministic,
-                rngs=rngs,
-            )
-            control_hidden_states = control_hidden_states + attn_output
 
-        # 3. Feed-forward
-        with jax.named_scope("ffn"):
-            norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-                control_hidden_states.dtype
-            )
-            ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-            control_hidden_states = (
-                control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
-            ).astype(control_hidden_states.dtype)
-        conditioning_states = None
-        if self.apply_output_projection:
-            conditioning_states = self.proj_out(control_hidden_states)
+            axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+            control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
+            control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
+            axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
+            encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
+
+            # 1. Self-attention
+            with self.conditional_named_scope("self_attn"):
+                with self.conditional_named_scope("self_attn_norm"):
+                    norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+                        control_hidden_states.dtype
+                    )
+                with self.conditional_named_scope("self_attn_attn"):
+                    attn_output = self.attn1(
+                        hidden_states=norm_hidden_states,
+                        encoder_hidden_states=norm_hidden_states,
+                        rotary_emb=rotary_emb,
+                        deterministic=deterministic,
+                        rngs=rngs,
+                    )
+                with self.conditional_named_scope("self_attn_residual"):
+                    control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
+                        control_hidden_states.dtype
+                    )
 
-        return conditioning_states, control_hidden_states
+            # 2. Cross-attention
+            with self.conditional_named_scope("cross_attn"):
+                with self.conditional_named_scope("cross_attn_norm"):
+                    norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
+                with self.conditional_named_scope("cross_attn_attn"):
+                    attn_output = self.attn2(
+                        hidden_states=norm_hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        deterministic=deterministic,
+                        rngs=rngs,
+                    )
+                with self.conditional_named_scope("cross_attn_residual"):
+                    control_hidden_states = control_hidden_states + attn_output
+
+            # 3. Feed-forward
+            with self.conditional_named_scope("mlp"):
+                with self.conditional_named_scope("mlp_norm"):
+                    norm_hidden_states = (
+                        self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
+                    ).astype(control_hidden_states.dtype)
+                with self.conditional_named_scope("mlp_ffn"):
+                    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+                with self.conditional_named_scope("mlp_residual"):
+                    control_hidden_states = (
+                        control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
+                    ).astype(control_hidden_states.dtype)
+
+            with self.conditional_named_scope("output_projection"):
+                conditioning_states = None
+                if self.apply_output_projection:
+                    conditioning_states = self.proj_out(control_hidden_states)
+
+        return conditioning_states, control_hidden_states
 
 
 class WanVACEModel(WanModel):
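In the `__call__` hunk above, the hard-coded `PartitionSpec("data", "fsdp", ...)` constraints are replaced by named logical axes resolved through `nn.logical_to_mesh_axes`, so the logical-to-mesh mapping lives in one rule set instead of being scattered through the model. A minimal sketch, assuming rules along these lines (the actual rules come from the surrounding codebase's sharding config, not this PR):

import flax.linen as nn

# Hypothetical rules mapping logical axis names to mesh axes.
rules = (
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_heads", "tensor"),
)
with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
# spec == PartitionSpec("data", "fsdp", "tensor"), i.e. the spec the old code hard-coded.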
@@ -289,7 +319,11 @@ def __init__(
         remat_policy: str = "None",
         names_which_can_be_saved: list[str] = [],
         names_which_can_be_offloaded: list[str] = [],
+        mask_padding_tokens: bool = True,
         scan_layers: bool = True,
+        enable_jax_named_scopes: bool = False,
+        use_base2_exp: bool = False,
+        use_experimental_scheduler: bool = False,
     ):
         """Initializes the VACE model.
 
@@ -302,6 +336,7 @@ def __init__(
         out_channels = out_channels or in_channels
         self.num_layers = num_layers
         self.scan_layers = scan_layers
+        self.enable_jax_named_scopes = enable_jax_named_scopes
 
         # 1. Patch & position embedding
         self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +364,7 @@ def __init__(
             text_embed_dim=text_dim,
             image_embed_dim=image_dim,
             pos_embed_seq_len=pos_embed_seq_len,
+            flash_min_seq_length=flash_min_seq_length,
         )
 
         self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +394,10 @@ def __init__(
                 precision=precision,
                 attention=attention,
                 dropout=dropout,
+                mask_padding_tokens=mask_padding_tokens,
+                enable_jax_named_scopes=enable_jax_named_scopes,
+                use_base2_exp=use_base2_exp,
+                use_experimental_scheduler=use_experimental_scheduler,
             )
             blocks.append(block)
         self.blocks = blocks
@@ -384,8 +424,12 @@ def __init__(
                 precision=precision,
                 attention=attention,
                 dropout=dropout,
+                mask_padding_tokens=mask_padding_tokens,
+                enable_jax_named_scopes=enable_jax_named_scopes,
                 apply_input_projection=vace_block_id == 0,
                 apply_output_projection=True,
+                use_base2_exp=use_base2_exp,
+                use_experimental_scheduler=use_experimental_scheduler,
            )
             vace_blocks.append(vace_block)
         self.vace_blocks = vace_blocks
@@ -421,6 +465,10 @@ def __init__(
             kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
         )
 
+    def conditional_named_scope(self, name: str):
+        """Return a JAX named scope if enabled, otherwise a null context."""
+        return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
     @jax.named_scope("WanVACEModel")
     def __call__(
         self,
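Note that the class-level `@jax.named_scope("WanVACEModel")` stays unconditional while the scopes inside `__call__` are gated by the new flag; `jax.named_scope` supports both the decorator and context-manager forms, and the names nest. A small illustration:

import jax
import jax.numpy as jnp

@jax.named_scope("WanVACEModel")
def f(x):
    with jax.named_scope("patch_embedding"):
        return x + 1.0

# Ops inside the inner block carry the nested name stack
# "WanVACEModel/patch_embedding" in profiler traces and HLO metadata.
f(jnp.ones(3))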
@@ -436,7 +484,7 @@ def __call__(
         rngs: nnx.Rngs = None,
     ) -> jax.Array:
         hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        batch_size, _, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.config.patch_size
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
@@ -453,13 +501,14 @@ def __call__(
 
         hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
         control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
-        rotary_emb = self.rope(hidden_states)
-
-        hidden_states = self.patch_embedding(hidden_states)
-        hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-        control_hidden_states = self.vace_patch_embedding(control_hidden_states)
-        control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+        with self.conditional_named_scope("rotary_embedding"):
+            rotary_emb = self.rope(hidden_states)
+        with self.conditional_named_scope("patch_embedding"):
+            hidden_states = self.patch_embedding(hidden_states)
+            hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+            control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+            control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
         control_hidden_states_padding = jnp.zeros((
             batch_size,
             control_hidden_states.shape[1],
@@ -469,16 +518,17 @@ def __call__(
         control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
 
         # Condition embedder is a FC layer.
-        (
-            temb,
-            timestep_proj,
-            encoder_hidden_states,
-            encoder_hidden_states_image,
-            _,
-        ) = self.condition_embedder(  # We will need to mask out the text embedding.
-            timestep, encoder_hidden_states, encoder_hidden_states_image
-        )
-        timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+        with self.conditional_named_scope("condition_embedder"):
+            (
+                temb,
+                timestep_proj,
+                encoder_hidden_states,
+                encoder_hidden_states_image,
+                _,
+            ) = self.condition_embedder(  # We will need to mask out the text embedding.
+                timestep, encoder_hidden_states, encoder_hidden_states_image
+            )
+            timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
         if encoder_hidden_states_image is not None:
             raise NotImplementedError("img2vid is not yet implemented.")
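For reference, the `jax.lax.collapse(x, 1, -1)` calls in the patch-embedding step merge the contiguous dimension range [1, -1) into a single axis, turning the (frames, height, width) patch grid into one token sequence with channels last. A quick shape check with illustrative sizes:

import jax
import jax.numpy as jnp

x = jnp.zeros((2, 4, 6, 8, 128))          # (batch, frames, height, width, channels)
seq = jax.lax.collapse(x, 1, -1)          # collapse dims 1..3 into one axis
assert seq.shape == (2, 4 * 6 * 8, 128)   # (batch, sequence, channels)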