@@ -374,6 +374,7 @@ def __call__(
374374 deterministic : bool = True ,
375375 rngs : nnx .Rngs = None ,
376376 encoder_attention_mask : Optional [jax .Array ] = None ,
377+ cached_kv : Optional [Dict [str , Tuple [jax .Array , jax .Array ]]] = None ,
377378 ):
378379 with self .conditional_named_scope ("transformer_block" ):
379380 shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
@@ -413,6 +414,7 @@ def __call__(
413414 deterministic = deterministic ,
414415 rngs = rngs ,
415416 encoder_attention_mask = encoder_attention_mask ,
417+ cached_kv = cached_kv ,
416418 )
417419 with self .conditional_named_scope ("cross_attn_residual" ):
418420 hidden_states = hidden_states + attn_output
@@ -431,6 +433,13 @@ def __call__(
431433 )
432434 return hidden_states
433435
def compute_kv(
    self,
    encoder_hidden_states: jax.Array,
    encoder_attention_mask: Optional[jax.Array] = None,
) -> Dict[str, Tuple[jax.Array, jax.Array]]:
  """Return the result of ``self.attn2.compute_kv`` for this block.

  This is a pure pass-through: both arguments are forwarded positionally and
  the result of the ``attn2`` sub-module is returned unchanged.

  Args:
    encoder_hidden_states: Conditioning states handed to ``attn2``.
    encoder_attention_mask: Optional mask forwarded alongside the states;
      ``None`` is passed through as-is.

  Returns:
    Whatever ``self.attn2.compute_kv`` returns. The annotation declares a
    dict mapping a string to a pair of arrays (presumably key/value
    projections -- confirm against the ``attn2`` implementation).
  """
  cross_attention = self.attn2
  return cross_attention.compute_kv(encoder_hidden_states, encoder_attention_mask)
442+
434443
435444class WanModel (nnx .Module , FlaxModelMixin , ConfigMixin ):
436445
@@ -583,6 +592,53 @@ def conditional_named_scope(self, name: str):
583592 """Return a JAX named scope if enabled, otherwise a null context."""
584593 return jax .named_scope (name ) if self .enable_jax_named_scopes else contextlib .nullcontext ()
585594
def compute_kv_cache(
    self,
    encoder_hidden_states: jax.Array,
    encoder_hidden_states_image: Optional[jax.Array] = None,
    timestep: Optional[jax.Array] = None,
) -> Dict[str, Tuple[jax.Array, jax.Array]]:
  """Collect every block's ``compute_kv`` output, stacked along a leading layer axis.

  The conditioning inputs are first run through ``self.condition_embedder``.
  When image states are present they are concatenated in front of the text
  states along axis 1, and a non-``None`` mask from the embedder is extended
  with ones to cover the remaining positions. Each block's ``compute_kv`` is
  then evaluated on the result.

  Args:
    encoder_hidden_states: Text conditioning states; axis 0 is used as the
      batch size when ``timestep`` is not supplied.
    encoder_hidden_states_image: Optional image conditioning states.
    timestep: Optional timestep passed to the condition embedder. Defaults to
      an int32 zero vector with one entry per batch element.

  Returns:
    A dict whose values are pairs of arrays with the per-layer results on
    axis 0 (via ``nnx.vmap`` when ``self.scan_layers`` is set, via
    ``jnp.stack`` otherwise).
  """
  if timestep is None:
    timestep = jnp.zeros((encoder_hidden_states.shape[0],), dtype=jnp.int32)

  with self.conditional_named_scope("condition_embedder"):
    embedded = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
  # The first two outputs (time embedding and its projection) are not needed here.
  _, _, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = embedded

  if encoder_hidden_states_image is not None:
    image_len = encoder_hidden_states_image.shape[1]
    encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
    if encoder_attention_mask is not None:
      # Extend the embedder's mask with ones for the positions after the image part.
      batch, total_len = encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]
      ones_tail = jnp.ones((batch, total_len - image_len), dtype=jnp.int32)
      encoder_attention_mask = jnp.concatenate([encoder_attention_mask, ones_tail], axis=1)

  if self.scan_layers:
    # Blocks are mapped over their leading axis; the states and mask are broadcast.
    @nnx.vmap(in_axes=(0, None, None), out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
    def _per_layer_kv(layer, states, mask):
      return layer.compute_kv(states, mask)

    return _per_layer_kv(self.blocks, encoder_hidden_states, encoder_attention_mask)

  per_layer = [layer.compute_kv(encoder_hidden_states, encoder_attention_mask) for layer in self.blocks]
  # Entry names are taken from the first layer; all layers must provide them.
  return {
      name: (
          jnp.stack([entry[name][0] for entry in per_layer], axis=0),
          jnp.stack([entry[name][1] for entry in per_layer], axis=0),
      )
      for name in per_layer[0]
  }
641+
586642 @jax .named_scope ("WanModel" )
587643 def __call__ (
588644 self ,
@@ -597,6 +653,7 @@ def __call__(
597653 skip_blocks : Optional [jax .Array ] = None ,
598654 cached_residual : Optional [jax .Array ] = None ,
599655 return_residual : bool = False ,
656+ kv_cache : Optional [Dict [str , Tuple [jax .Array , jax .Array ]]] = None ,
600657 ) -> Union [jax .Array , Tuple [jax .Array , jax .Array ], Dict [str , jax .Array ]]:
601658 hidden_states = nn .with_logical_constraint (hidden_states , ("batch" , None , None , None , None ))
602659 batch_size , _ , num_frames , height , width = hidden_states .shape
@@ -634,8 +691,14 @@ def __call__(
634691 def _run_all_blocks (h ):
635692 if self .scan_layers :
636693
637- def scan_fn (carry , block ):
694+ def scan_fn (carry , block_input ):
638695 hidden_states_carry , rngs_carry = carry
696+ if kv_cache is not None :
697+ block , layer_kv_cache = block_input
698+ else :
699+ block = block_input
700+ layer_kv_cache = None
701+
639702 hidden_states = block (
640703 hidden_states_carry ,
641704 encoder_hidden_states ,
@@ -644,6 +707,7 @@ def scan_fn(carry, block):
644707 deterministic ,
645708 rngs_carry ,
646709 encoder_attention_mask ,
710+ cached_kv = layer_kv_cache ,
647711 )
648712 new_carry = (hidden_states , rngs_carry )
649713 return new_carry , None
@@ -652,19 +716,28 @@ def scan_fn(carry, block):
652716 scan_fn , self .names_which_can_be_saved , self .names_which_can_be_offloaded , prevent_cse = not self .scan_layers
653717 )
654718 initial_carry = (h , rngs )
719+
720+ if kv_cache is not None :
721+ scan_input = (self .blocks , kv_cache )
722+ else :
723+ scan_input = self .blocks
724+
655725 final_carry , _ = nnx .scan (
656726 rematted_block_forward ,
657727 length = self .num_layers ,
658728 in_axes = (nnx .Carry , 0 ),
659729 out_axes = (nnx .Carry , 0 ),
660- )(initial_carry , self . blocks )
730+ )(initial_carry , scan_input )
661731
662732 h_out , _ = final_carry
663733 else :
664734 h_out = h
665- for block in self .blocks :
735+ for i , block in enumerate (self .blocks ):
736+ layer_kv_cache = None
737+ if kv_cache is not None :
738+ layer_kv_cache = jax .tree_map (lambda x : x [i ], kv_cache )
666739
667- def layer_forward (hidden_states ):
740+ def layer_forward (hidden_states , l_kv ):
668741 return block (
669742 hidden_states ,
670743 encoder_hidden_states ,
@@ -673,6 +746,7 @@ def layer_forward(hidden_states):
673746 deterministic ,
674747 rngs ,
675748 encoder_attention_mask = encoder_attention_mask ,
749+ cached_kv = l_kv ,
676750 )
677751
678752 rematted_layer_forward = self .gradient_checkpoint .apply (
@@ -681,7 +755,7 @@ def layer_forward(hidden_states):
681755 self .names_which_can_be_offloaded ,
682756 prevent_cse = not self .scan_layers ,
683757 )
684- h_out = rematted_layer_forward (h_out )
758+ h_out = rematted_layer_forward (h_out , layer_kv_cache )
685759 return h_out
686760
687761 hidden_states_before_blocks = hidden_states
0 commit comments