Skip to content

Commit 8861b92

Browse files
committed
fix in audio vae
1 parent 0a123e5 commit 8861b92

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2_audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def __init__(
506506
)
507507
curr_res = curr_res // 2
508508

509-
self.down_stages.append({"blocks": stage_blocks, "attns": stage_attns, "downsample": downsample})
509+
self.down_stages.append(nnx.Dict({"blocks": stage_blocks, "attns": stage_attns, "downsample": downsample}))
510510

511511
self.mid_block1 = FlaxLTX2AudioResnetBlock(
512512
block_in,
@@ -671,7 +671,7 @@ def __init__(
671671
)
672672
curr_res *= 2
673673

674-
self.up_stages.append({"blocks": stage_blocks, "attns": stage_attns, "upsample": upsample})
674+
self.up_stages.append(nnx.Dict({"blocks": stage_blocks, "attns": stage_attns, "upsample": upsample}))
675675

676676
if self.norm_type == "group":
677677
self.norm_out = nnx.GroupNorm(num_groups=32, num_channels=block_in, epsilon=1e-6, dtype=dtype, rngs=rngs)

0 commit comments

Comments
 (0)