@@ -384,9 +384,9 @@ def __init__(
384384 dtype = self .dtype ,
385385 )
386386
387- self .ups = nnx .List ()
387+ self .upsamplers = nnx .List ()
388388 for i , (stride , kernel_size ) in enumerate (zip (upsample_rates , upsample_kernel_sizes )):
389- self .ups .append (
389+ self .upsamplers .append (
390390 nnx .ConvTranspose (
391391 in_features = upsample_initial_channel // (2 ** i ),
392392 out_features = upsample_initial_channel // (2 ** (i + 1 )),
@@ -398,11 +398,11 @@ def __init__(
398398 )
399399 )
400400
401- self .resblocks = nnx .List ()
401+ self .resnets = nnx .List ()
402402 for i in range (len (upsample_rates )):
403403 ch = upsample_initial_channel // (2 ** (i + 1 ))
404404 for kernel_size , dilations in zip (resblock_kernel_sizes , resblock_dilation_sizes ):
405- self .resblocks .append (AMPBlock1 (ch , kernel_size , dilations , activation = activation , rngs = rngs ))
405+ self .resnets .append (AMPBlock1 (ch , kernel_size , dilations , activation = activation , rngs = rngs ))
406406
407407 final_channels = upsample_initial_channel // (2 ** len (upsample_rates ))
408408 self .act_out = Activation1d (final_channels , SnakeBeta (final_channels , rngs = rngs ), rngs = rngs )
@@ -429,14 +429,14 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
429429 hidden_states = self .conv_in (hidden_states )
430430
431431 for i in range (self .num_upsamples ):
432- hidden_states = self .ups [i ](hidden_states )
432+ hidden_states = self .upsamplers [i ](hidden_states )
433433
434434 start = i * self .num_kernels
435435 end = (i + 1 ) * self .num_kernels
436436
437437 res_sum = 0.0
438438 for j in range (start , end ):
439- res_sum = res_sum + self .resblocks [j ](hidden_states )
439+ res_sum = res_sum + self .resnets [j ](hidden_states )
440440
441441 hidden_states = res_sum / self .num_kernels
442442
0 commit comments