Commit d5b6da3

update shardings in attn.
Parent: a8f80b7

2 files changed

Lines changed: 5 additions & 3 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 3 deletions
@@ -132,11 +132,10 @@ logical_axis_rules: [
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
-  ['norm', 'fsdp'],
+  ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'fsdp'],
-  ['conv_in', 'fsdp']
+  ['conv_in', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
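
For context (an editorial note, not part of the commit): each entry of logical_axis_rules pairs a logical axis name, as used in parameter sharding annotations, with the mesh axis (or axes) it is partitioned over. After this change, parameters annotated with 'norm' are sharded over the 'tensor' mesh axis instead of 'fsdp', and the dedicated 'conv_out' rule is gone, so 'conv_out'-annotated dimensions fall back to replication. A minimal sketch of how such rules resolve, using the generic flax.linen helper (maxdiffusion's own plumbing may differ; the rule list below is copied from the updated hunk):

import flax.linen as nn

# Logical-to-mesh rules as they read after this commit (copied from the updated
# base_wan_14b.yml hunk above).
rules = (
    ("mlp", "tensor"),
    ("embed", "fsdp"),
    ("heads", "tensor"),
    ("norm", "tensor"),                # was ('norm', 'fsdp') before this commit
    ("conv_batch", ("data", "fsdp")),
    ("out_channels", "tensor"),
    ("conv_in", "fsdp"),               # the 'conv_out' rule was removed
)

# A parameter dimension annotated with the logical name 'norm' now resolves to
# the 'tensor' mesh axis; before this commit it resolved to 'fsdp'.
print(nn.logical_to_mesh_axes(("norm",), rules))      # PartitionSpec('tensor',)

# Logical names with no rule, such as 'conv_out' after this commit, map to
# None, i.e. that dimension is replicated.
print(nn.logical_to_mesh_axes(("conv_out",), rules))  # PartitionSpec(None,)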

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 0 deletions
@@ -686,6 +686,7 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
+      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.key = nnx.Linear(
@@ -696,6 +697,7 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
+      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.value = nnx.Linear(
@@ -706,6 +708,7 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
+      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.proj_attn = nnx.Linear(
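
For context (an editorial note, not part of the commit): nnx.with_partitioning wraps an initializer so the parameter it creates carries sharding metadata. With the added bias_init, the query/key/value bias vectors are tagged with the logical axis 'embed', which the logical_axis_rules in base_wan_14b.yml ('embed' -> 'fsdp') can later resolve to a mesh axis. A minimal standalone sketch of the mechanism; the feature sizes and the kernel annotation below are illustrative, not taken from attention_flax.py:

from flax import nnx

# Illustrative layer; in attention_flax.py this corresponds to self.query,
# self.key and self.value.
layer = nnx.Linear(
    in_features=128,
    out_features=128,
    # Illustrative kernel annotation; the real code may tag the kernel with
    # different logical axes.
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    # The line this commit adds: a zeros bias initializer whose output is
    # tagged with the logical axis name 'embed'.
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    rngs=nnx.Rngs(0),
)

# The annotation travels with the parameter state; nnx.get_partition_spec turns
# it into the PartitionSpecs that jit / with_sharding_constraint consume.
state = nnx.state(layer)
print(nnx.get_partition_spec(state))
# In this sketch the bias leaf carries PartitionSpec('embed',) and the kernel
# leaf PartitionSpec('embed', 'heads').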
