Skip to content

Commit eaa96e5

Browse files
committed
Fix
1 parent 7918a6b commit eaa96e5

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def scan_fn(carry, input_slice):
612612
# Expand Time dimension for Conv3d
613613
input_slice = jnp.expand_dims(input_slice, 1)
614614
out_slice, new_carry = self.decoder(input_slice, carry)
615+
out_slice = out_slice.astype(jnp.bfloat16)
615616
# Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
616617
return new_carry, out_slice
617618

0 commit comments

Comments
 (0)