Skip to content

Commit 433b2d4

Browse files
committed
Fix errors
1 parent 211f58f commit 433b2d4

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -411,19 +411,19 @@ def __init__(self, rngs: nnx.Rngs, dim: int = 128, z_dim: int = 4, dim_mult=[1,
411411
self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=3, out_channels=dims[0], kernel_size=3, padding=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
412412

413413
# We need a structured way to access blocks for cache init, so we separate them
414-
self.down_blocks_layers = []
414+
down_blocks_layers = []
415415

416416
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
417417
for _ in range(num_res_blocks):
418-
self.down_blocks_layers.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
418+
down_blocks_layers.append(WanResidualBlock(in_dim=in_dim, out_dim=out_dim, dropout=dropout, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
419419
if scale in attn_scales:
420-
self.down_blocks_layers.append(WanAttentionBlock(dim=out_dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
420+
down_blocks_layers.append(WanAttentionBlock(dim=out_dim, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
421421
in_dim = out_dim
422422
if i != len(dim_mult) - 1:
423423
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
424-
self.down_blocks_layers.append(WanResample(out_dim, mode=mode, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
424+
down_blocks_layers.append(WanResample(out_dim, mode=mode, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
425425
scale /= 2.0
426-
self.down_blocks = nnx.data(self.down_blocks_layers)
426+
self.down_blocks = nnx.data(down_blocks_layers)
427427

428428
self.mid_block = WanMidBlock(dim=out_dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
429429
self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs)
@@ -489,14 +489,14 @@ def __init__(self, rngs: nnx.Rngs, dim: int = 128, z_dim: int = 4, dim_mult: Lis
489489
self.conv_in = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
490490
self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
491491

492-
self.up_blocks = []
492+
up_blocks = []
493493
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
494494
if i > 0: in_dim = in_dim // 2
495495
upsample_mode = None
496496
if i != len(dim_mult) - 1:
497497
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
498-
self.up_blocks.append(WanUpBlock(in_dim=in_dim, out_dim=out_dim, num_res_blocks=num_res_blocks, dropout=dropout, upsample_mode=upsample_mode, non_linearity=non_linearity, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
499-
self.up_blocks = nnx.data(self.up_blocks)
498+
up_blocks.append(WanUpBlock(in_dim=in_dim, out_dim=out_dim, num_res_blocks=num_res_blocks, dropout=dropout, upsample_mode=upsample_mode, non_linearity=non_linearity, rngs=rngs, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision))
499+
self.up_blocks = nnx.data(up_blocks)
500500

501501
self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
502502
self.conv_out = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)

0 commit comments

Comments
 (0)