Skip to content

Commit dcfdb4a

Browse files
committed
fix linting
1 parent 6313546 commit dcfdb4a

1 file changed

Lines changed: 2 additions & 13 deletions

File tree

src/maxdiffusion/tests/wan_animate/test_transformer_wan_animate.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -16,14 +16,9 @@
1616

1717
import os
1818
import torch
19-
import torch.nn as nn
20-
import torch.nn.functional as F
2119
import numpy as np
2220
import jax.numpy as jnp
2321
from flax import nnx
24-
import math
25-
import sys
26-
import pytest
2722

2823
# Import from the codebase
2924
# Make sure to set Python path or import correctly
@@ -42,9 +37,6 @@
4237
from jax.sharding import Mesh
4338

4439

45-
46-
47-
4840
def transfer_conv_weights(pt_conv, jax_conv):
4941
if hasattr(jax_conv, "weight"):
5042
jax_conv.weight[...] = jnp.array(pt_conv.weight.detach().numpy())
@@ -227,8 +219,6 @@ def transfer_transformer_weights(pt_model, jax_model):
227219
jax_model.scale_shift_table[...] = jnp.array(pt_model.scale_shift_table.detach().numpy())
228220

229221

230-
231-
232222
class TestWanAnimateTransformer:
233223

234224
def test_motion_conv_equivalence(self):
@@ -244,6 +234,7 @@ def test_motion_conv_equivalence(self):
244234
x_jax = jnp.array(x_np)
245235

246236
from diffusers.models.transformers.transformer_wan_animate import MotionConv2d
237+
247238
pt_model = MotionConv2d(
248239
in_channels=C_in,
249240
out_channels=C_out,
@@ -697,5 +688,3 @@ def test_equivalence_wan_animate_transformer(self):
697688

698689
assert np_pt.shape == np_jax.shape
699690
np.testing.assert_allclose(np_pt, np_jax, rtol=1e-4, atol=1e-4)
700-
701-

0 commit comments

Comments
 (0)