@@ -661,13 +661,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
661661 new_cache = {"resnets" : []}
662662
663663 x , c = self .resnets [0 ](x , cache .get ("resnets" , [None ])[0 ])
664- new_cache ["resnets" ].append (c . astype ( self . dtype ) )
664+ new_cache ["resnets" ].append (c )
665665
666666 for i , (attn , resnet ) in enumerate (zip (self .attentions , self .resnets [1 :])):
667667 if attn is not None :
668668 x = attn (x )
669669 x , c = resnet (x , cache .get ("resnets" , [None ] * len (self .resnets ))[i + 1 ])
670- new_cache ["resnets" ].append (c . astype ( self . dtype ) )
670+ new_cache ["resnets" ].append (c )
671671
672672 return x , new_cache
673673
@@ -741,11 +741,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
741741
742742 for i , resnet in enumerate (self .resnets ):
743743 x , c = resnet (x , cache .get ("resnets" , [None ] * len (self .resnets ))[i ])
744- new_cache ["resnets" ].append (c . astype ( self . dtype ) )
744+ new_cache ["resnets" ].append (c )
745745
746746 if self .upsamplers :
747747 x , c = self .upsamplers [0 ](x , cache .get ("upsamplers" , [None ])[0 ])
748- new_cache ["upsamplers" ].append (c . astype ( self . dtype ) )
748+ new_cache ["upsamplers" ].append (c )
749749 return x , new_cache
750750
751751
@@ -899,7 +899,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
899899 for i , layer in enumerate (self .down_blocks ):
900900 if isinstance (layer , (WanResidualBlock , WanResample )):
901901 x , c = layer (x , current_down_caches [i ])
902- new_cache ["down_blocks" ].append (c . astype ( self . dtype ) )
902+ new_cache ["down_blocks" ].append (c )
903903 else :
904904 x = layer (x )
905905 new_cache ["down_blocks" ].append (None )
@@ -1038,7 +1038,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
10381038 current_up_caches = cache .get ("up_blocks" , [None ] * len (self .up_blocks ))
10391039 for i , up_block in enumerate (self .up_blocks ):
10401040 x , c = up_block (x , current_up_caches [i ])
1041- new_cache ["up_blocks" ].append (c . astype ( self . dtype ) )
1041+ new_cache ["up_blocks" ].append (c )
10421042
10431043 x = self .norm_out (x )
10441044 x = self .nonlinearity (x )
0 commit comments