@@ -411,19 +411,19 @@ def __init__(self, rngs: nnx.Rngs, dim: int = 128, z_dim: int = 4, dim_mult=[1,
411411 self .conv_in = WanCausalConv3d (rngs = rngs , in_channels = 3 , out_channels = dims [0 ], kernel_size = 3 , padding = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
412412
413413 # We need a structured way to access blocks for cache init, so we separate them
414- self . down_blocks_layers = []
414+ down_blocks_layers = []
415415
416416 for i , (in_dim , out_dim ) in enumerate (zip (dims [:- 1 ], dims [1 :])):
417417 for _ in range (num_res_blocks ):
418- self . down_blocks_layers .append (WanResidualBlock (in_dim = in_dim , out_dim = out_dim , dropout = dropout , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
418+ down_blocks_layers .append (WanResidualBlock (in_dim = in_dim , out_dim = out_dim , dropout = dropout , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
419419 if scale in attn_scales :
420- self . down_blocks_layers .append (WanAttentionBlock (dim = out_dim , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
420+ down_blocks_layers .append (WanAttentionBlock (dim = out_dim , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
421421 in_dim = out_dim
422422 if i != len (dim_mult ) - 1 :
423423 mode = "downsample3d" if temperal_downsample [i ] else "downsample2d"
424- self . down_blocks_layers .append (WanResample (out_dim , mode = mode , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
424+ down_blocks_layers .append (WanResample (out_dim , mode = mode , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
425425 scale /= 2.0
426- self .down_blocks = nnx .data (self . down_blocks_layers )
426+ self .down_blocks = nnx .data (down_blocks_layers )
427427
428428 self .mid_block = WanMidBlock (dim = out_dim , rngs = rngs , dropout = dropout , non_linearity = non_linearity , num_layers = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
429429 self .norm_out = WanRMS_norm (out_dim , channel_first = False , images = False , rngs = rngs )
@@ -489,14 +489,14 @@ def __init__(self, rngs: nnx.Rngs, dim: int = 128, z_dim: int = 4, dim_mult: Lis
489489 self .conv_in = WanCausalConv3d (rngs = rngs , in_channels = z_dim , out_channels = dims [0 ], kernel_size = 3 , padding = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
490490 self .mid_block = WanMidBlock (dim = dims [0 ], rngs = rngs , dropout = dropout , non_linearity = non_linearity , num_layers = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
491491
492- self . up_blocks = []
492+ up_blocks = []
493493 for i , (in_dim , out_dim ) in enumerate (zip (dims [:- 1 ], dims [1 :])):
494494 if i > 0 : in_dim = in_dim // 2
495495 upsample_mode = None
496496 if i != len (dim_mult ) - 1 :
497497 upsample_mode = "upsample3d" if temperal_upsample [i ] else "upsample2d"
498- self . up_blocks .append (WanUpBlock (in_dim = in_dim , out_dim = out_dim , num_res_blocks = num_res_blocks , dropout = dropout , upsample_mode = upsample_mode , non_linearity = non_linearity , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
499- self .up_blocks = nnx .data (self . up_blocks )
498+ up_blocks .append (WanUpBlock (in_dim = in_dim , out_dim = out_dim , num_res_blocks = num_res_blocks , dropout = dropout , upsample_mode = upsample_mode , non_linearity = non_linearity , rngs = rngs , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision ))
499+ self .up_blocks = nnx .data (up_blocks )
500500
501501 self .norm_out = WanRMS_norm (dim = out_dim , images = False , rngs = rngs , channel_first = False )
502502 self .conv_out = WanCausalConv3d (rngs = rngs , in_channels = out_dim , out_channels = 3 , kernel_size = 3 , padding = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
0 commit comments