1313limitations under the License.
1414"""
1515
16+ import contextlib
1617import math
1718from typing import Any , Dict , Optional , Tuple
1819
@@ -62,8 +63,12 @@ def __init__(
6263 precision : jax .lax .Precision | None = None ,
6364 attention : str = "dot_product" ,
6465 dropout : float = 0.0 ,
66+ mask_padding_tokens : bool = True ,
67+ enable_jax_named_scopes : bool = False ,
6568 apply_input_projection : bool = False ,
6669 apply_output_projection : bool = False ,
70+ use_base2_exp : bool = False ,
71+ use_experimental_scheduler : bool = False ,
6772 ):
6873 """Sets up the model.
6974
@@ -90,7 +95,7 @@ def __init__(
9095 apply_output_projection: Whether to apply an output projection before
9196 outputting the result.
9297 """
93-
98+ self . enable_jax_named_scopes = enable_jax_named_scopes
9499 self .apply_input_projection = apply_input_projection
95100 self .apply_output_projection = apply_output_projection
96101
@@ -124,7 +129,12 @@ def __init__(
124129 precision = precision ,
125130 attention_kernel = attention ,
126131 dropout = dropout ,
132+ is_self_attention = True ,
133+ mask_padding_tokens = mask_padding_tokens ,
127134 residual_checkpoint_name = "self_attn" ,
135+ enable_jax_named_scopes = enable_jax_named_scopes ,
136+ use_base2_exp = use_base2_exp ,
137+ use_experimental_scheduler = use_experimental_scheduler ,
128138 )
129139
130140 # 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
143153 precision = precision ,
144154 attention_kernel = attention ,
145155 dropout = dropout ,
156+ is_self_attention = False ,
157+ mask_padding_tokens = mask_padding_tokens ,
146158 residual_checkpoint_name = "cross_attn" ,
159+ enable_jax_named_scopes = enable_jax_named_scopes ,
160+ use_base2_exp = use_base2_exp ,
161+ use_experimental_scheduler = use_experimental_scheduler ,
147162 )
148163 assert cross_attn_norm is True , "cross_attn_norm must be True"
149164 self .norm2 = FP32LayerNorm (rngs = rngs , dim = dim , eps = eps , elementwise_affine = True )
@@ -158,6 +173,7 @@ def __init__(
158173 weights_dtype = weights_dtype ,
159174 precision = precision ,
160175 dropout = dropout ,
176+ enable_jax_named_scopes = enable_jax_named_scopes ,
161177 )
162178
163179 self .norm3 = FP32LayerNorm (rngs = rngs , dim = dim , eps = eps , elementwise_affine = False )
@@ -180,6 +196,10 @@ def __init__(
180196 jax .random .normal (key , (1 , 6 , dim )) / dim ** 0.5 ,
181197 )
182198
def conditional_named_scope(self, name: str):
  """Pick the context manager used to wrap a named region of the forward pass.

  Args:
    name: Label forwarded to `jax.named_scope` when scoping is enabled.

  Returns:
    `jax.named_scope(name)` if `self.enable_jax_named_scopes` is truthy,
    otherwise a fresh `contextlib.nullcontext()`.
  """
  # Guard clause: when scoping is switched off, hand back a no-op context and
  # never call into jax.named_scope.
  if not self.enable_jax_named_scopes:
    return contextlib.nullcontext()
  return jax.named_scope(name)
202+
183203 def __call__ (
184204 self ,
185205 * ,
@@ -191,65 +211,76 @@ def __call__(
191211 deterministic : bool = True ,
192212 rngs : nnx .Rngs | None = None ,
193213 ) -> Tuple [jax .Array , jax .Array ]:
194- if self .apply_input_projection :
195- control_hidden_states = self .proj_in (control_hidden_states )
196- control_hidden_states = control_hidden_states + hidden_states
197-
198- shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
199- (self .adaln_scale_shift_table + temb .astype (jnp .float32 )), 6 , axis = 1
200- )
201-
202- control_hidden_states = jax .lax .with_sharding_constraint (
203- control_hidden_states ,
204- PartitionSpec ("data" , "fsdp" , "tensor" ),
205- )
206- control_hidden_states = checkpoint_name (control_hidden_states , "control_hidden_states" )
207- encoder_hidden_states = jax .lax .with_sharding_constraint (
208- encoder_hidden_states ,
209- PartitionSpec ("data" , "fsdp" , None ),
210- )
211-
212- # 1. Self-attention
213- with jax .named_scope ("attn1" ):
214- norm_hidden_states = (self .norm1 (control_hidden_states .astype (jnp .float32 )) * (1 + scale_msa ) + shift_msa ).astype (
215- control_hidden_states .dtype
216- )
217- attn_output = self .attn1 (
218- hidden_states = norm_hidden_states ,
219- encoder_hidden_states = norm_hidden_states ,
220- rotary_emb = rotary_emb ,
221- deterministic = deterministic ,
222- rngs = rngs ,
223- )
224- control_hidden_states = (control_hidden_states .astype (jnp .float32 ) + attn_output * gate_msa ).astype (
225- control_hidden_states .dtype
226- )
227214
228- # 2. Cross-attention
229- with jax .named_scope ("attn2" ):
230- norm_hidden_states = self .norm2 (control_hidden_states .astype (jnp .float32 )).astype (control_hidden_states .dtype )
231- attn_output = self .attn2 (
232- hidden_states = norm_hidden_states ,
233- encoder_hidden_states = encoder_hidden_states ,
234- deterministic = deterministic ,
235- rngs = rngs ,
215+ with self .conditional_named_scope ("vace_transformer_block" ):
216+ with self .conditional_named_scope ("input_projection" ):
217+ if self .apply_input_projection :
218+ control_hidden_states = self .proj_in (control_hidden_states )
219+ control_hidden_states = control_hidden_states + hidden_states
220+ shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
221+ (self .adaln_scale_shift_table + temb .astype (jnp .float32 )),
222+ 6 ,
223+ axis = 1 ,
236224 )
237- control_hidden_states = control_hidden_states + attn_output
238225
239- # 3. Feed-forward
240- with jax .named_scope ("ffn" ):
241- norm_hidden_states = (self .norm3 (control_hidden_states .astype (jnp .float32 )) * (1 + c_scale_msa ) + c_shift_msa ).astype (
242- control_hidden_states .dtype
243- )
244- ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
245- control_hidden_states = (
246- control_hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa
247- ).astype (control_hidden_states .dtype )
248- conditioning_states = None
249- if self .apply_output_projection :
250- conditioning_states = self .proj_out (control_hidden_states )
226+ axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_heads" ))
227+ control_hidden_states = jax .lax .with_sharding_constraint (control_hidden_states , axis_names )
228+ control_hidden_states = checkpoint_name (control_hidden_states , "control_hidden_states" )
229+ axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_kv" ))
230+ encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , axis_names )
231+
232+ # 1. Self-attention
233+ with self .conditional_named_scope ("self_attn" ):
234+ with self .conditional_named_scope ("self_attn_norm" ):
235+ norm_hidden_states = (self .norm1 (control_hidden_states .astype (jnp .float32 )) * (1 + scale_msa ) + shift_msa ).astype (
236+ control_hidden_states .dtype
237+ )
238+ with self .conditional_named_scope ("self_attn_attn" ):
239+ attn_output = self .attn1 (
240+ hidden_states = norm_hidden_states ,
241+ encoder_hidden_states = norm_hidden_states ,
242+ rotary_emb = rotary_emb ,
243+ deterministic = deterministic ,
244+ rngs = rngs ,
245+ )
246+ with self .conditional_named_scope ("self_attn_residual" ):
247+ control_hidden_states = (control_hidden_states .astype (jnp .float32 ) + attn_output * gate_msa ).astype (
248+ control_hidden_states .dtype
249+ )
251250
252- return conditioning_states , control_hidden_states
251+ # 2. Cross-attention
252+ with self .conditional_named_scope ("cross_attn" ):
253+ with self .conditional_named_scope ("cross_attn_norm" ):
254+ norm_hidden_states = self .norm2 (control_hidden_states .astype (jnp .float32 )).astype (control_hidden_states .dtype )
255+ with self .conditional_named_scope ("cross_attn_attn" ):
256+ attn_output = self .attn2 (
257+ hidden_states = norm_hidden_states ,
258+ encoder_hidden_states = encoder_hidden_states ,
259+ deterministic = deterministic ,
260+ rngs = rngs ,
261+ )
262+ with self .conditional_named_scope ("cross_attn_residual" ):
263+ control_hidden_states = control_hidden_states + attn_output
264+
265+ # 3. Feed-forward
266+ with self .conditional_named_scope ("mlp" ):
267+ with self .conditional_named_scope ("mlp_norm" ):
268+ norm_hidden_states = (
269+ self .norm3 (control_hidden_states .astype (jnp .float32 )) * (1 + c_scale_msa ) + c_shift_msa
270+ ).astype (control_hidden_states .dtype )
271+ with self .conditional_named_scope ("mlp_ffn" ):
272+ ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
273+ with self .conditional_named_scope ("mlp_residual" ):
274+ control_hidden_states = (
275+ control_hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa
276+ ).astype (control_hidden_states .dtype )
277+
278+ with self .conditional_named_scope ("output_projection" ):
279+ conditioning_states = None
280+ if self .apply_output_projection :
281+ conditioning_states = self .proj_out (control_hidden_states )
282+
283+ return conditioning_states , control_hidden_states
253284
254285
255286class WanVACEModel (WanModel ):
@@ -289,7 +320,11 @@ def __init__(
289320 remat_policy : str = "None" ,
290321 names_which_can_be_saved : list [str ] = [],
291322 names_which_can_be_offloaded : list [str ] = [],
323+ mask_padding_tokens : bool = True ,
292324 scan_layers : bool = True ,
325+ enable_jax_named_scopes : bool = False ,
326+ use_base2_exp : bool = False ,
327+ use_experimental_scheduler : bool = False ,
293328 ):
294329 """Initializes the VACE model.
295330
@@ -302,6 +337,7 @@ def __init__(
302337 out_channels = out_channels or in_channels
303338 self .num_layers = num_layers
304339 self .scan_layers = scan_layers
340+ self .enable_jax_named_scopes = enable_jax_named_scopes
305341
306342 # 1. Patch & position embedding
307343 self .rope = WanRotaryPosEmbed (attention_head_dim , patch_size , rope_max_seq_len )
@@ -329,6 +365,7 @@ def __init__(
329365 text_embed_dim = text_dim ,
330366 image_embed_dim = image_dim ,
331367 pos_embed_seq_len = pos_embed_seq_len ,
368+ flash_min_seq_length = flash_min_seq_length ,
332369 )
333370
334371 self .gradient_checkpoint = GradientCheckpointType .from_str (remat_policy )
@@ -358,6 +395,10 @@ def __init__(
358395 precision = precision ,
359396 attention = attention ,
360397 dropout = dropout ,
398+ mask_padding_tokens = mask_padding_tokens ,
399+ enable_jax_named_scopes = enable_jax_named_scopes ,
400+ use_base2_exp = use_base2_exp ,
401+ use_experimental_scheduler = use_experimental_scheduler ,
361402 )
362403 blocks .append (block )
363404 self .blocks = blocks
@@ -384,8 +425,12 @@ def __init__(
384425 precision = precision ,
385426 attention = attention ,
386427 dropout = dropout ,
428+ mask_padding_tokens = mask_padding_tokens ,
429+ enable_jax_named_scopes = enable_jax_named_scopes ,
387430 apply_input_projection = vace_block_id == 0 ,
388431 apply_output_projection = True ,
432+ use_base2_exp = use_base2_exp ,
433+ use_experimental_scheduler = use_experimental_scheduler ,
389434 )
390435 vace_blocks .append (vace_block )
391436 self .vace_blocks = vace_blocks
@@ -421,6 +466,10 @@ def __init__(
421466 kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , "embed" )),
422467 )
423468
def conditional_named_scope(self, name: str):
  """Select the context manager for a named section of the model call.

  Args:
    name: Label passed to `jax.named_scope` when scoping is enabled.

  Returns:
    `jax.named_scope(name)` when `self.enable_jax_named_scopes` is truthy;
    a new `contextlib.nullcontext()` otherwise.
  """
  if self.enable_jax_named_scopes:
    scope = jax.named_scope(name)
  else:
    # Scoping disabled: a null context leaves the wrapped code unaffected.
    scope = contextlib.nullcontext()
  return scope
472+
424473 @jax .named_scope ("WanVACEModel" )
425474 def __call__ (
426475 self ,
@@ -436,7 +485,7 @@ def __call__(
436485 rngs : nnx .Rngs = None ,
437486 ) -> jax .Array :
438487 hidden_states = nn .with_logical_constraint (hidden_states , ("batch" , None , None , None , None ))
439- batch_size , num_channels , num_frames , height , width = hidden_states .shape
488+ batch_size , _ , num_frames , height , width = hidden_states .shape
440489 p_t , p_h , p_w = self .config .patch_size
441490 post_patch_num_frames = num_frames // p_t
442491 post_patch_height = height // p_h
@@ -453,32 +502,34 @@ def __call__(
453502
454503 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 3 , 4 , 1 ))
455504 control_hidden_states = jnp .transpose (control_hidden_states , (0 , 2 , 3 , 4 , 1 ))
456- rotary_emb = self .rope (hidden_states )
457-
458- hidden_states = self .patch_embedding (hidden_states )
459- hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
460-
461- control_hidden_states = self .vace_patch_embedding (control_hidden_states )
462- control_hidden_states = jax .lax .collapse (control_hidden_states , 1 , - 1 )
505+ with self .conditional_named_scope ("rotary_embedding" ):
506+ rotary_emb = self .rope (hidden_states )
507+ with self .conditional_named_scope ("patch_embedding" ):
508+ hidden_states = self .patch_embedding (hidden_states )
509+ hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
510+
511+ control_hidden_states = self .vace_patch_embedding (control_hidden_states )
512+ control_hidden_states = jax .lax .collapse (control_hidden_states , 1 , - 1 )
463513 control_hidden_states_padding = jnp .zeros ((
464- batch_size ,
465- control_hidden_states .shape [1 ],
466- hidden_states .shape [2 ] - control_hidden_states .shape [2 ],
514+ batch_size ,
515+ control_hidden_states .shape [1 ],
516+ hidden_states .shape [2 ] - control_hidden_states .shape [2 ],
467517 ))
468518
469519 control_hidden_states = jnp .concatenate ([control_hidden_states , control_hidden_states_padding ], axis = 2 )
470520
471521 # Condition embedder is a FC layer.
472- (
473- temb ,
474- timestep_proj ,
475- encoder_hidden_states ,
476- encoder_hidden_states_image ,
477- _ ,
478- ) = self .condition_embedder ( # We will need to mask out the text embedding.
479- timestep , encoder_hidden_states , encoder_hidden_states_image
480- )
481- timestep_proj = timestep_proj .reshape (timestep_proj .shape [0 ], 6 , - 1 )
522+ with self .conditional_named_scope ("condition_embedder" ):
523+ (
524+ temb ,
525+ timestep_proj ,
526+ encoder_hidden_states ,
527+ encoder_hidden_states_image ,
528+ _ ,
529+ ) = self .condition_embedder ( # We will need to mask out the text embedding.
530+ timestep , encoder_hidden_states , encoder_hidden_states_image
531+ )
532+ timestep_proj = timestep_proj .reshape (timestep_proj .shape [0 ], 6 , - 1 )
482533
483534 if encoder_hidden_states_image is not None :
484535 raise NotImplementedError ("img2vid is not yet implemented." )
0 commit comments