Skip to content

Commit f5342fd

Browse files
committed
fix for dtypes
1 parent a996cce commit f5342fd

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1067,8 +1067,8 @@ def __init__(
10671067
latents_mean: List[float] = [],
10681068
latents_std: List[float] = [],
10691069
mesh: jax.sharding.Mesh = None,
1070-
dtype: jnp.dtype = jnp.float32,
1071-
weights_dtype: jnp.dtype = jnp.float32,
1070+
dtype: jnp.dtype = jnp.bfloat16,
1071+
weights_dtype: jnp.dtype = jnp.bfloat16,
10721072
precision: jax.lax.Precision = None,
10731073
):
10741074
self.z_dim = z_dim
@@ -1132,6 +1132,7 @@ def encode(
11321132
if x.shape[-1] != 3:
11331133
x = jnp.transpose(x, (0, 2, 3, 4, 1))
11341134

1135+
x = x.astype(jnp.bfloat16)
11351136
x_scan = jnp.swapaxes(x, 0, 1)
11361137
b, t, h, w, c = x.shape
11371138
init_cache = self.encoder.init_cache(b, h, w, jnp.bfloat16)
@@ -1161,12 +1162,12 @@ def decode(
11611162
) -> Union[FlaxDecoderOutput, jax.Array]:
11621163
if z.shape[-1] != self.z_dim:
11631164
z = jnp.transpose(z, (0, 2, 3, 4, 1))
1164-
1165+
z = z.astype(jnp.bfloat16)
11651166
x, _ = self.post_quant_conv(z)
11661167
x_scan = jnp.swapaxes(x, 0, 1)
11671168

11681169
b, t, h, w, c = x.shape
1169-
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
1170+
init_cache = self.decoder.init_cache(b, h, w, jnp.bfloat16)
11701171

11711172
def scan_fn(carry, input_slice):
11721173
# Expand Time dimension for Conv3d
@@ -1189,6 +1190,7 @@ def scan_fn(carry, input_slice):
11891190
decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
11901191

11911192
out = jnp.clip(decoded, min=-1.0, max=1.0)
1193+
out = out.astype(jnp.float32)
11921194

11931195
if not return_dict:
11941196
return (out,)

0 commit comments

Comments
 (0)