We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4e985c3 commit 04ed46dCopy full SHA for 04ed46d
1 file changed
src/maxdiffusion/tests/test_ltx2_utils.py
@@ -185,7 +185,7 @@ def test_load_vocoder_weights(self):
185
else:
186
pt_shape = shape
187
188
- pt_weights[pt_key] = torch.randn(pt_shape)
+ pt_weights[pt_key] = jnp.array(torch.randn(pt_shape).numpy())
189
190
with mock.patch("maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint", return_value=pt_weights):
191
loaded_weights = load_vocoder_weights(
0 commit comments