@@ -1067,8 +1067,8 @@ def __init__(
10671067 latents_mean : List [float ] = [],
10681068 latents_std : List [float ] = [],
10691069 mesh : jax .sharding .Mesh = None ,
1070- dtype : jnp .dtype = jnp .float32 ,
1071- weights_dtype : jnp .dtype = jnp .float32 ,
1070+ dtype : jnp .dtype = jnp .bfloat16 ,
1071+ weights_dtype : jnp .dtype = jnp .bfloat16 ,
10721072 precision : jax .lax .Precision = None ,
10731073 ):
10741074 self .z_dim = z_dim
@@ -1132,6 +1132,7 @@ def encode(
11321132 if x .shape [- 1 ] != 3 :
11331133 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
11341134
1135+ x = x .astype (jnp .bfloat16 )
11351136 x_scan = jnp .swapaxes (x , 0 , 1 )
11361137 b , t , h , w , c = x .shape
11371138 init_cache = self .encoder .init_cache (b , h , w , jnp .bfloat16 )
@@ -1161,12 +1162,12 @@ def decode(
11611162 ) -> Union [FlaxDecoderOutput , jax .Array ]:
11621163 if z .shape [- 1 ] != self .z_dim :
11631164 z = jnp .transpose (z , (0 , 2 , 3 , 4 , 1 ))
1164-
1165+ z = z . astype ( jnp . bfloat16 )
11651166 x , _ = self .post_quant_conv (z )
11661167 x_scan = jnp .swapaxes (x , 0 , 1 )
11671168
11681169 b , t , h , w , c = x .shape
1169- init_cache = self .decoder .init_cache (b , h , w , x . dtype )
1170+ init_cache = self .decoder .init_cache (b , h , w , jnp . bfloat16 )
11701171
11711172 def scan_fn (carry , input_slice ):
11721173 # Expand Time dimension for Conv3d
@@ -1189,6 +1190,7 @@ def scan_fn(carry, input_slice):
11891190 decoded = decoded .reshape (b , t_lat * t_sub , h , w , c )
11901191
11911192 out = jnp .clip (decoded , min = - 1.0 , max = 1.0 )
1193+ out = out .astype (jnp .float32 )
11921194
11931195 if not return_dict :
11941196 return (out ,)
0 commit comments