We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7918a6b commit eaa96e5Copy full SHA for eaa96e5
1 file changed
src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -612,6 +612,7 @@ def scan_fn(carry, input_slice):
612
# Expand Time dimension for Conv3d
613
input_slice = jnp.expand_dims(input_slice, 1)
614
out_slice, new_carry = self.decoder(input_slice, carry)
615
+ out_slice = out_slice.astype(jnp.bfloat16)
616
# Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
617
return new_carry, out_slice
618
0 commit comments