Skip to content

Commit 04ed46d

Browse files
committed
fix
1 parent 4e985c3 commit 04ed46d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_load_vocoder_weights(self):
185185
else:
186186
pt_shape = shape
187187

188-
pt_weights[pt_key] = torch.randn(pt_shape)
188+
pt_weights[pt_key] = jnp.array(torch.randn(pt_shape).numpy())
189189

190190
with mock.patch("maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint", return_value=pt_weights):
191191
loaded_weights = load_vocoder_weights(

0 commit comments

Comments
 (0)