@@ -61,8 +61,8 @@ def __init__(
       padding: Union[int, Tuple[int, int, int]] = 0,
       use_bias: bool = True,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
@@ -270,8 +270,8 @@ def __init__(
       mode: str,
       rngs: nnx.Rngs,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dtype = dtype
@@ -443,8 +443,8 @@ def __init__(
       dropout: float = 0.0,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dtype = dtype
@@ -511,19 +511,19 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     input_dtype = x.dtype

     h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
-    new_cache["shortcut"] = sc_cache
+    new_cache["shortcut"] = sc_cache.astype(self.dtype)

     x = self.norm1(x)
     x = self.nonlinearity(x)

     x, c1 = self.conv1(x, cache.get("conv1"))
-    new_cache["conv1"] = c1
+    new_cache["conv1"] = c1.astype(self.dtype)

     x = self.norm2(x)
     x = self.nonlinearity(x)

     x, c2 = self.conv2(x, cache.get("conv2"))
-    new_cache["conv2"] = c2
+    new_cache["conv2"] = c2.astype(self.dtype)

     x = (x + h).astype(self.dtype)
     return x, new_cache
@@ -535,8 +535,8 @@ def __init__(
       dim: int,
       rngs: nnx.Rngs,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dim = dim
@@ -597,10 +597,11 @@ def __init__(
       non_linearity: str = "silu",
       num_layers: int = 1,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.resnets = nnx.List(
         [
@@ -657,13 +658,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {"resnets": []}

     x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
-    new_cache["resnets"].append(c)
+    new_cache["resnets"].append(c.astype(self.dtype))

     for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
       if attn is not None:
         x = attn(x)
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
-      new_cache["resnets"].append(c)
+      new_cache["resnets"].append(c.astype(self.dtype))

     return x, new_cache

@@ -679,10 +680,11 @@ def __init__(
       upsample_mode: Optional[str] = None,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.resnets = nnx.List([])
     current_dim = in_dim
     for _ in range(num_res_blocks + 1):
@@ -736,11 +738,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):

     for i, resnet in enumerate(self.resnets):
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i])
-      new_cache["resnets"].append(c)
+      new_cache["resnets"].append(c.astype(self.dtype))

     if self.upsamplers:
       x, c = self.upsamplers[0](x, cache.get("upsamplers", [None])[0])
-      new_cache["upsamplers"].append(c)
+      new_cache["upsamplers"].append(c.astype(self.dtype))
     return x, new_cache


@@ -757,10 +759,11 @@ def __init__(
      dropout=0.0,
      non_linearity: str = "silu",
      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.z_dim = z_dim
     self.dim_mult = dim_mult
@@ -885,27 +888,27 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {}

     x, c = self.conv_in(x, cache.get("conv_in"))
-    new_cache["conv_in"] = c
+    new_cache["conv_in"] = c.astype(self.dtype)

     new_cache["down_blocks"] = []
     current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))

     for i, layer in enumerate(self.down_blocks):
       if isinstance(layer, (WanResidualBlock, WanResample)):
         x, c = layer(x, current_down_caches[i])
-        new_cache["down_blocks"].append(c)
+        new_cache["down_blocks"].append(c.astype(self.dtype))
       else:
         x = layer(x)
         new_cache["down_blocks"].append(None)

     x, c = self.mid_block(x, cache.get("mid_block"))
-    new_cache["mid_block"] = c
+    new_cache["mid_block"] = c.astype(self.dtype)

     x = self.norm_out(x)
     x = self.nonlinearity(x)

     x, c = self.conv_out(x, cache.get("conv_out"))
-    new_cache["conv_out"] = c
+    new_cache["conv_out"] = c.astype(self.dtype)

     return x, new_cache

@@ -923,10 +926,11 @@ def __init__(
      dropout=0.0,
      non_linearity: str = "silu",
      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.dim_mult = dim_mult
     self.nonlinearity = get_activation(non_linearity)
@@ -1022,21 +1026,21 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {}

     x, c = self.conv_in(x, cache.get("conv_in"))
-    new_cache["conv_in"] = c
+    new_cache["conv_in"] = c.astype(self.dtype)

     x, c = self.mid_block(x, cache.get("mid_block"))
-    new_cache["mid_block"] = c
+    new_cache["mid_block"] = c.astype(self.dtype)

     new_cache["up_blocks"] = []
     current_up_caches = cache.get("up_blocks", [None] * len(self.up_blocks))
     for i, up_block in enumerate(self.up_blocks):
       x, c = up_block(x, current_up_caches[i])
-      new_cache["up_blocks"].append(c)
+      new_cache["up_blocks"].append(c.astype(self.dtype))

     x = self.norm_out(x)
     x = self.nonlinearity(x)
     x, c = self.conv_out(x, cache.get("conv_out"))
-    new_cache["conv_out"] = c
+    new_cache["conv_out"] = c.astype(self.dtype)

     return x, new_cache

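Taken together, the commit flips the Wan VAE's compute and weight dtypes to bfloat16 and casts every cached activation (the causal-conv tails threaded through each `__call__` as `cache`/`new_cache`) to `self.dtype` before storing it. Below is a minimal, self-contained sketch of that caching pattern; `causal_conv_chunk`, its arguments, and the dtype policy are illustrative assumptions, not this file's API.

# Sketch (hypothetical helper, not this repo's API): a chunked causal conv whose
# carried-over tail is explicitly cast to the compute dtype, mirroring the diff's
# `c.astype(self.dtype)`. Without the cast, any upstream float32 promotion would
# silently double the cache's memory footprint under the new bfloat16 default.
import jax.numpy as jnp

def causal_conv_chunk(x, kernel, cache, dtype=jnp.bfloat16):
  # x:      [T, C] current chunk of frames.
  # kernel: [K, C] per-channel filter (illustrative stand-in for lax.conv).
  # cache:  [K-1, C] tail frames carried from the previous chunk, or None.
  k = kernel.shape[0]
  if cache is None:
    # First chunk: zero history, i.e. causal left padding.
    cache = jnp.zeros((k - 1,) + x.shape[1:], dtype=dtype)
  full = jnp.concatenate([cache.astype(x.dtype), x], axis=0)
  out = jnp.stack([(full[t : t + k] * kernel).sum(axis=0) for t in range(x.shape[0])])
  # Store the tail for the next chunk in the compute dtype -- the cast the diff adds.
  new_cache = full[-(k - 1) :].astype(dtype)
  return out, new_cache

Calling it chunk by chunk, `out_i, cache = causal_conv_chunk(x_i, w, cache)`, should match a single causally padded convolution over the concatenated chunks, up to bfloat16 rounding of the carried tail; that rounding is the trade the commit makes for halving the streaming cache's memory.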