@@ -128,7 +128,13 @@ def __init__(
         upsample_factors: Sequence[int] = (6, 5, 2, 2, 2),
         resnet_kernel_sizes: Sequence[int] = (3, 7, 11),
         resnet_dilations: Sequence[Sequence[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        act_fn: str = "leaky_relu",
         leaky_relu_negative_slope: float = 0.1,
+        antialias: bool = False,
+        antialias_ratio: int = 2,
+        antialias_kernel_size: int = 12,
+        final_act_fn: Optional[str] = None,
+        final_bias: bool = False,
         # output_sampling_rate is unused in model structure but kept for config compat
         output_sampling_rate: int = 24000,
         *,
@@ -141,6 +147,8 @@ def __init__(
         self.total_upsample_factor = math.prod(upsample_factors)
         self.negative_slope = leaky_relu_negative_slope
         self.act_fn = act_fn
+        self.final_act_fn = final_act_fn
+        self.final_bias = final_bias
         self.dtype = dtype

         if self.num_upsample_layers != len(upsample_factors):
@@ -219,6 +227,7 @@ def __init__(
             kernel_size=(7,),
             strides=(1,),
             padding="SAME",
+            use_bias=self.final_bias,
             rngs=rngs,
             dtype=self.dtype,
         )
0 commit comments