Skip to content

Commit 20ca585

Browse files
committed
match new flax version
1 parent 8fdf3c2 commit 20ca585

3 files changed

Lines changed: 10 additions & 8 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,8 +795,8 @@ def __init__(
795795

796796
self.drop_out = nnx.Dropout(dropout)
797797

798-
self.norm_q = None
799-
self.norm_k = None
798+
self.norm_q = nnx.data(None)
799+
self.norm_k = nnx.data(None)
800800
if qk_norm is not None:
801801
self.norm_q = nnx.RMSNorm(
802802
num_features=self.inner_dim,

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def __init__(
225225
):
226226
self.dim = dim
227227
self.mode = mode
228-
self.time_conv = None
228+
self.time_conv = nnx.data(None)
229229

230230
if mode == "upsample2d":
231231
self.resample = nnx.Sequential(
@@ -554,8 +554,8 @@ def __init__(
554554
precision=precision,
555555
)
556556
)
557-
self.attentions = attentions
558-
self.resnets = resnets
557+
self.attentions = nnx.data(attentions)
558+
self.resnets = nnx.data(resnets)
559559

560560
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
561561
x = self.resnets[0](x, feat_cache, feat_idx)
@@ -601,10 +601,10 @@ def __init__(
601601
)
602602
)
603603
current_dim = out_dim
604-
self.resnets = resnets
604+
self.resnets = nnx.data(resnets)
605605

606606
# Add upsampling layer if needed.
607-
self.upsamplers = None
607+
self.upsamplers = nnx.data(None)
608608
if upsample_mode is not None:
609609
self.upsamplers = [
610610
WanResample(
@@ -710,6 +710,7 @@ def __init__(
710710
)
711711
)
712712
scale /= 2.0
713+
self.down_blocks = nnx.data(self.down_blocks)
713714

714715
# middle_blocks
715716
self.mid_block = WanMidBlock(
@@ -873,6 +874,7 @@ def __init__(
873874
# Update scale for next iteration
874875
if upsample_mode is not None:
875876
scale *= 2.0
877+
self.up_blocks = nnx.data(self.up_blocks)
876878

877879
# output blocks
878880
self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def __init__(
209209
inner_dim = int(dim * mult)
210210
dim_out = dim_out if dim_out is not None else dim
211211

212-
self.act_fn = None
212+
self.act_fn = nnx.data(None)
213213
if activation_fn == "gelu-approximate":
214214
self.act_fn = ApproximateGELU(
215215
rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision

0 commit comments

Comments
 (0)