Skip to content

Commit 44224c2

Browse files
committed
linting.
1 parent 0df5659 commit 44224c2

1 file changed

Lines changed: 35 additions & 5 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,13 @@ def __init__(
688688
dtype=dtype,
689689
param_dtype=weights_dtype,
690690
precision=precision,
691-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
691+
bias_init=nnx.with_partitioning(
692+
nnx.initializers.zeros,
693+
(
694+
None,
695+
"embed",
696+
),
697+
),
692698
)
693699

694700
self.key = nnx.Linear(
@@ -699,7 +705,13 @@ def __init__(
699705
dtype=dtype,
700706
param_dtype=weights_dtype,
701707
precision=precision,
702-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
708+
bias_init=nnx.with_partitioning(
709+
nnx.initializers.zeros,
710+
(
711+
None,
712+
"embed",
713+
),
714+
),
703715
)
704716

705717
self.value = nnx.Linear(
@@ -710,7 +722,13 @@ def __init__(
710722
dtype=dtype,
711723
param_dtype=weights_dtype,
712724
precision=precision,
713-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
725+
bias_init=nnx.with_partitioning(
726+
nnx.initializers.zeros,
727+
(
728+
None,
729+
"embed",
730+
),
731+
),
714732
)
715733

716734
self.proj_attn = nnx.Linear(
@@ -731,15 +749,27 @@ def __init__(
731749
rngs=rngs,
732750
epsilon=eps,
733751
dtype=dtype,
734-
scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)),
752+
scale_init=nnx.with_partitioning(
753+
nnx.initializers.ones,
754+
(
755+
None,
756+
"norm",
757+
),
758+
),
735759
param_dtype=weights_dtype,
736760
)
737761

738762
self.norm_k = nnx.RMSNorm(
739763
num_features=self.inner_dim,
740764
rngs=rngs,
741765
dtype=dtype,
742-
scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)),
766+
scale_init=nnx.with_partitioning(
767+
nnx.initializers.ones,
768+
(
769+
None,
770+
"norm",
771+
),
772+
),
743773
param_dtype=weights_dtype,
744774
)
745775

0 commit comments

Comments
 (0)