Skip to content

Commit 516cf91

Browse files
committed
fix for dtypes
1 parent 0897096 commit 516cf91

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def __call__(
172172
padding_to_apply = tuple(current_padding)
173173
if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads):
174174
x_padded = jnp.pad(
175-
x_input, padding_to_apply, mode="constant", constant_values=0.0
175+
x_input, padding_to_apply, mode="constant", constant_values=jnp.array(0.0, dtype=self.dtype)
176176
)
177177
else:
178178
x_padded = x_input

0 commit comments

Comments
 (0)