Skip to content

Commit 88b86c4

Browse files
committed
dtype of cache initialisation changed
1 parent 8e313fc commit 88b86c4

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,7 @@ def encode(
11321132

11331133
x_scan = jnp.swapaxes(x, 0, 1)
11341134
b, t, h, w, c = x.shape
1135-
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
1135+
init_cache = self.encoder.init_cache(b, h, w, jnp.bfloat16)
11361136

11371137
def scan_fn(carry, input_slice):
11381138
# Expand Time dimension for Conv3d

0 commit comments

Comments
 (0)