Commit cde8ab8
fix transformer names
1 parent 1770d86 commit cde8ab8

2 files changed: 36 additions & 35 deletions

src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py

Lines changed: 22 additions & 21 deletions
@@ -48,7 +48,7 @@
 }


-class FlaxFusedLeakyReLU(nnx.Module):
+class FusedLeakyReLU(nnx.Module):
   """
   Fused LeakyRelu with scale factor and channel-wise bias.
   """
@@ -84,7 +84,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
     return x


-class FlaxMotionConv2d(nnx.Module):
+class MotionConv2d(nnx.Module):
   """2-D convolution with EqualizedLR scaling and optional FusedLeakyReLU.

   Weights are stored in PyTorch OIHW format (out, in, k, k) as raw nnx.Param
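The EqualizedLR + OIHW convention the docstring references maps directly onto JAX, since lax convolutions accept PyTorch's kernel layout. A hedged sketch, assuming the standard 1/sqrt(fan_in) runtime scale (the actual scale used in this file is outside the hunk):

import jax
import jax.numpy as jnp

def equalized_lr_conv2d(x, weight, stride=1, padding=1):
  """Hypothetical helper: apply an OIHW kernel (out, in, k, k) to NCHW inputs."""
  out_ch, in_ch, k, _ = weight.shape
  scale = 1.0 / jnp.sqrt(in_ch * k * k)  # equalized LR: scale by 1/sqrt(fan_in) at call time
  return jax.lax.conv_general_dilated(
      x,
      weight * scale,
      window_strides=(stride, stride),
      padding=[(padding, padding), (padding, padding)],
      dimension_numbers=("NCHW", "OIHW", "NCHW"),  # matches the raw PyTorch layout
  )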
@@ -148,7 +148,7 @@ def __init__(
       self.bias = None

     if self.use_activation:
-      self.act_fn = FlaxFusedLeakyReLU(
+      self.act_fn = FusedLeakyReLU(
          rngs=rngs, bias_channels=out_channels, dtype=dtype, weights_dtype=weights_dtype
      )
     else:
@@ -205,11 +205,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
     return x


-class FlaxMotionLinear(nnx.Module):
+class MotionLinear(nnx.Module):
   """Equalized-LR linear layer with optional FusedLeakyReLU.

   Weights are stored in PyTorch (out, in) format as raw nnx.Param — same
-  reason as FlaxMotionConv2d. No sharding annotations needed (small layer).
+  reason as MotionConv2d. No sharding annotations needed (small layer).
   """

   def __init__(
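The linear layer follows the same equalized-LR pattern with the (out, in) layout its docstring mentions; a short sketch under the same 1/sqrt(fan_in) assumption:

import jax.numpy as jnp

def equalized_lr_linear(x, weight, bias=None):
  """Hypothetical helper: weight is PyTorch (out, in); scale applied per call."""
  out_dim, in_dim = weight.shape
  y = x @ (weight * (1.0 / jnp.sqrt(in_dim))).T
  return y if bias is None else y + bias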
@@ -238,7 +238,7 @@ def __init__(
       self.bias = None

     if self.use_activation:
-      self.act_fn = FlaxFusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype, weights_dtype=weights_dtype)
+      self.act_fn = FusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype, weights_dtype=weights_dtype)
     else:
       self.act_fn = None

@@ -258,7 +258,7 @@ def __call__(self, inputs: jax.Array, channel_dim: int = 1) -> jax.Array:
     return out


-class FlaxMotionEncoderResBlock(nnx.Module):
+class MotionEncoderResBlock(nnx.Module):

   def __init__(
       self,
@@ -276,7 +276,7 @@ def __init__(
     self.dtype = dtype

     # 3 X 3 Conv + fused leaky ReLU
-    self.conv1 = FlaxMotionConv2d(
+    self.conv1 = MotionConv2d(
         rngs,
         in_channels,
         in_channels,
@@ -289,7 +289,7 @@ def __init__(
     )

     # 3 X 3 Conv + downsample 2x + fused leaky ReLU
-    self.conv2 = FlaxMotionConv2d(
+    self.conv2 = MotionConv2d(
         rngs,
         in_channels,
         out_channels,
@@ -303,7 +303,7 @@ def __init__(
     )

     # 1 X 1 Conv + downsample 2x in skip connection
-    self.conv_skip = FlaxMotionConv2d(
+    self.conv_skip = MotionConv2d(
         rngs,
         in_channels,
         out_channels,
@@ -327,7 +327,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
     return x_out


-class FlaxWanAnimateMotionEncoder(nnx.Module):
+class WanAnimateMotionEncoder(nnx.Module):
   """Encodes a face video frame into a motion vector.

   All weights in this network are small (the largest is 32×512→16) so
@@ -353,7 +353,7 @@ def __init__(
     if channels is None:
       channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES

-    self.conv_in = FlaxMotionConv2d(
+    self.conv_in = MotionConv2d(
         rngs, 3, channels[str(size)], 1, use_activation=True, dtype=dtype, weights_dtype=weights_dtype
     )

@@ -363,12 +363,12 @@ def __init__(
     for i in range(log_size, 2, -1):
       out_channels = channels[str(2 ** (i - 1))]
       res_blocks.append(
-          FlaxMotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype, weights_dtype=weights_dtype)
+          MotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype, weights_dtype=weights_dtype)
       )
       in_channels = out_channels
     self.res_blocks = nnx.List(res_blocks)

-    self.conv_out = FlaxMotionConv2d(
+    self.conv_out = MotionConv2d(
         rngs,
         in_channels,
         style_dim,
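The loop above adds one 2x-downsampling res block per octave between `size` and 4. A hedged walkthrough using the small config from this commit's parity test (`size=16`, `channels={"4": 8, "8": 8, "16": 8}`); `log_size = int(math.log2(size))` and the initial `in_channels = channels[str(size)]` are assumptions consistent with conv_in above:

import math

size, channels = 16, {"4": 8, "8": 8, "16": 8}
log_size = int(math.log2(size))    # 4 (assumed definition)
in_channels = channels[str(size)]  # 8, the width conv_in produces
for i in range(log_size, 2, -1):   # i = 4, 3 -> two res blocks
  out_channels = channels[str(2 ** (i - 1))]
  print(f"res block: {in_channels} -> {out_channels} channels, spatial /2")
  in_channels = out_channels
# 16x16 -> 8x8 -> 4x4 before conv_out projects to style_dim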
@@ -382,9 +382,9 @@ def __init__(

     linears = []
     for _ in range(motion_blocks - 1):
-      linears.append(FlaxMotionLinear(rngs, style_dim, style_dim, dtype=dtype, weights_dtype=weights_dtype))
+      linears.append(MotionLinear(rngs, style_dim, style_dim, dtype=dtype, weights_dtype=weights_dtype))

-    linears.append(FlaxMotionLinear(rngs, style_dim, motion_dim, dtype=dtype, weights_dtype=weights_dtype))
+    linears.append(MotionLinear(rngs, style_dim, motion_dim, dtype=dtype, weights_dtype=weights_dtype))
     self.motion_network = nnx.List(linears)

     key = rngs.params()
@@ -417,7 +417,7 @@ def __call__(self, face_image: jax.Array, channel_dim: int = 1) -> jax.Array:
     return motion_vec.astype(original_dtype)


-class FlaxWanAnimateFaceEncoder(nnx.Module):
+class WanAnimateFaceEncoder(nnx.Module):

   def __init__(
       self,
@@ -544,7 +544,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
     return x


-class FlaxWanAnimateFaceBlockCrossAttention(nnx.Module):
+class WanAnimateFaceBlockCrossAttention(nnx.Module):

   def __init__(
       self,
@@ -763,7 +763,7 @@ def __init__(
         weights_dtype=weights_dtype,
     )

-    self.motion_encoder = FlaxWanAnimateMotionEncoder(
+    self.motion_encoder = WanAnimateMotionEncoder(
         rngs=rngs,
         size=motion_encoder_size,
         style_dim=motion_style_dim,
@@ -773,7 +773,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
     )
-    self.face_encoder = FlaxWanAnimateFaceEncoder(
+    self.face_encoder = WanAnimateFaceEncoder(
         rngs=rngs,
         in_dim=motion_encoder_dim,
         out_dim=inner_dim,
@@ -840,7 +840,7 @@ def init_block(rngs):
     face_adapters = []
     num_face_adapters = math.ceil(num_layers / inject_face_latents_blocks)
     for _ in range(num_face_adapters):
-      fa = FlaxWanAnimateFaceBlockCrossAttention(
+      fa = WanAnimateFaceBlockCrossAttention(
          rngs=rngs,
          dim=inner_dim,
          heads=num_attention_heads,
@@ -1081,3 +1081,4 @@ def layer_forward(hidden_states):
     if not return_dict:
       return (hidden_states,)
     return {"sample": hidden_states}
+

src/maxdiffusion/tests/wan_animate_module_parity_test.py

Lines changed: 14 additions & 14 deletions
@@ -47,14 +47,14 @@
 from maxdiffusion import pyconfig
 from maxdiffusion.max_utils import create_device_mesh
 from maxdiffusion.models.wan.transformers.transformer_wan_animate import (
-    FlaxFusedLeakyReLU,
-    FlaxMotionConv2d,
-    FlaxMotionEncoderResBlock,
-    FlaxMotionLinear,
-    FlaxWanAnimateFaceBlockCrossAttention,
-    FlaxWanAnimateFaceEncoder,
-    FlaxWanAnimateMotionEncoder,
+    FusedLeakyReLU,
+    MotionConv2d,
+    MotionEncoderResBlock,
+    MotionLinear,
     NNXWanAnimateTransformer3DModel,
+    WanAnimateFaceBlockCrossAttention,
+    WanAnimateFaceEncoder,
+    WanAnimateMotionEncoder,
 )
 from maxdiffusion.models.wan.wan_utils import (
     _rename_wan_animate_pt_tuple_key,
@@ -189,7 +189,7 @@ def setUp(self):

   def test_fused_leaky_relu_parity(self):
     hf_module = HFFusedLeakyReLU(bias_channels=3).eval()
-    max_module = FlaxFusedLeakyReLU(rngs=self.rngs, bias_channels=3)
+    max_module = FusedLeakyReLU(rngs=self.rngs, bias_channels=3)
     copy_fused_leaky_relu_params(max_module, hf_module)

     inputs = torch.randn(2, 3, 4, 5)
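The assertion bodies of these parity tests fall outside the hunks; the comparison presumably follows the usual pattern sketched below (`assert_module_parity` is a hypothetical name; only the copy_* helpers appear in the diff):

import jax.numpy as jnp
import numpy as np
import torch

def assert_module_parity(hf_module, max_module, torch_inputs, atol=1e-5):
  # Same weights (copied beforehand), same inputs, outputs compared within tolerance.
  with torch.no_grad():
    expected = hf_module(torch_inputs).numpy()
  got = np.asarray(max_module(jnp.asarray(torch_inputs.numpy())))
  np.testing.assert_allclose(got, expected, atol=atol)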
@@ -200,7 +200,7 @@ def test_fused_leaky_relu_parity(self):

   def test_motion_conv2d_parity(self):
     hf_module = HFMotionConv2d(3, 5, kernel_size=3, stride=2, padding=0, blur_kernel=(1, 3, 3, 1)).eval()
-    max_module = FlaxMotionConv2d(
+    max_module = MotionConv2d(
         rngs=self.rngs,
         in_channels=3,
         out_channels=5,
@@ -219,7 +219,7 @@ def test_motion_conv2d_parity(self):

   def test_motion_linear_parity(self):
     hf_module = HFMotionLinear(7, 5, use_activation=True).eval()
-    max_module = FlaxMotionLinear(rngs=self.rngs, in_dim=7, out_dim=5, use_activation=True)
+    max_module = MotionLinear(rngs=self.rngs, in_dim=7, out_dim=5, use_activation=True)
     copy_motion_linear_params(max_module, hf_module)

     inputs = torch.randn(4, 7)
@@ -230,7 +230,7 @@ def test_motion_linear_parity(self):

   def test_motion_encoder_resblock_parity(self):
     hf_module = HFMotionEncoderResBlock(8, 10).eval()
-    max_module = FlaxMotionEncoderResBlock(rngs=self.rngs, in_channels=8, out_channels=10)
+    max_module = MotionEncoderResBlock(rngs=self.rngs, in_channels=8, out_channels=10)
     copy_motion_encoder_resblock_params(max_module, hf_module)

     inputs = torch.randn(2, 8, 8, 8)
@@ -249,7 +249,7 @@ def test_motion_encoder_parity(self):
         "channels": {"4": 8, "8": 8, "16": 8},
     }
     hf_module = HFWanAnimateMotionEncoder(**cfg).eval()
-    max_module = FlaxWanAnimateMotionEncoder(rngs=self.rngs, **cfg)
+    max_module = WanAnimateMotionEncoder(rngs=self.rngs, **cfg)
     copy_motion_encoder_params(max_module, hf_module)

     inputs = torch.randn(3, 3, 4, 4)
@@ -260,7 +260,7 @@ def test_motion_encoder_parity(self):

   def test_face_encoder_parity(self):
     hf_module = HFWanAnimateFaceEncoder(in_dim=8, out_dim=12, hidden_dim=16, num_heads=2).eval()
-    max_module = FlaxWanAnimateFaceEncoder(rngs=self.rngs, in_dim=8, out_dim=12, hidden_dim=16, num_heads=2)
+    max_module = WanAnimateFaceEncoder(rngs=self.rngs, in_dim=8, out_dim=12, hidden_dim=16, num_heads=2)
     copy_face_encoder_params(max_module, hf_module)

     inputs = torch.randn(2, 7, 8)
@@ -271,7 +271,7 @@ def test_face_encoder_parity(self):

   def test_face_block_cross_attention_parity(self):
     hf_module = HFWanAnimateFaceBlockCrossAttention(dim=12, heads=3, dim_head=4, cross_attention_dim_head=4).eval()
-    max_module = FlaxWanAnimateFaceBlockCrossAttention(
+    max_module = WanAnimateFaceBlockCrossAttention(
         rngs=self.rngs,
         dim=12,
         heads=3,
