Commit 2bbf289

add fsdp/tp sharding of weights.
1 parent bee57ba commit 2bbf289

4 files changed: 27 additions & 18 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 3 deletions
```diff
@@ -126,11 +126,13 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
-  ['activation_length', 'fsdp'],
+  ['activation_length', ['fsdp','tensor']],
+  ['activation_kv_length', 'fsdp'],
   ['activation_heads', 'tensor'],
   ['activation_batch', 'data'],
-  ['mlp','tensor'],
-  ['embed','fsdp'],
+  #['mlp','tensor'],
+  ['embed',['fsdp', 'tensor']],
+  ['heads', ['tensor', 'fsdp']],
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
```
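The net effect of this hunk is that the `embed` and `heads` weight axes and the query sequence axis now shard over both the `fsdp` and `tensor` mesh axes, while the `mlp` rule is disabled. Below is a minimal sketch, not from this repository, of how Flax resolves such rules into a `PartitionSpec`; the rule values mirror the updated config, and the lookup is illustrative.

```python
# Minimal sketch: resolving logical axis rules to mesh axes with Flax.
# Names are logical axes from the config, not mesh axes themselves.
import flax.linen as nn

rules = (
    ("batch", "data"),
    ("activation_length", ("fsdp", "tensor")),
    ("embed", ("fsdp", "tensor")),
    ("heads", ("tensor", "fsdp")),
)

# For an ('embed', 'mlp') kernel: 'embed' now shards over both fsdp and
# tensor, while 'mlp' matches no rule and is left unsharded (None).
spec = nn.logical_to_mesh_axes(("embed", "mlp"), rules)
print(spec)  # PartitionSpec(('fsdp', 'tensor'), None)
```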

src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -187,9 +187,9 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+  #flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  #axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  #named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)

   shard_head_size = mesh.shape["tensor"]

@@ -210,7 +210,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):

   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  #segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

   @functools.partial(
       shard_map.shard_map,
@@ -219,7 +219,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          segment_axis_names_splash_kernel,
+          None,
       ),
       out_specs=q_axis_names,
       check_rep=False,
@@ -686,7 +686,7 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.key = nnx.Linear(
@@ -697,7 +697,7 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.value = nnx.Linear(
@@ -708,7 +708,7 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     self.proj_attn = nnx.Linear(
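```
With the splash kernel's `manual_sharding_spec` commented out (its `in_specs` entry becomes `None`), head sharding is instead derived from the size of the `tensor` mesh axis. Below is a minimal sketch, assuming a hypothetical 4-device host, of the `mesh.shape` lookup the code now relies on.

```python
# Minimal sketch (assumes 4 devices; reshape to your topology) of reading
# an axis size off a named mesh, as _tpu_flash_attention now does.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(1, 2, 2)  # hypothetical 1x2x2 layout
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Mesh.shape maps axis name -> size; here it determines how many ways the
# attention heads are split for the splash kernel.
shard_head_size = mesh.shape["tensor"]
print(shard_head_size)  # 2 on the assumed layout
```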

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -99,7 +99,7 @@ def __init__(
             "mlp",
           ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )

     if cond_proj_dim is not None:
@@ -131,7 +131,7 @@ def __init__(
             "embed",
           ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

     if post_act_fn is None:
@@ -275,11 +275,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
                 "mlp",
+                "embed",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
     self.act_1 = get_activation(act_fn)

@@ -294,11 +294,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
                 "embed",
+                "mlp",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

   def __call__(self, caption):
```
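The pattern in both projections is the same: the kernel keeps a logical-axis annotation (with the two axis names swapped relative to the old code), while the bias annotation is commented out, leaving the bias with no sharding metadata. A minimal sketch of that pattern with hypothetical dimensions, not taken from the repository:

```python
# Minimal sketch (hypothetical dims) of the annotation pattern after this
# commit: annotated kernel, unannotated (and therefore replicated) bias.
from flax import nnx

linear_1 = nnx.Linear(
    in_features=256,
    out_features=1024,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        ("mlp", "embed"),  # logical names, resolved via logical_axis_rules
    ),
    # bias_init left at its default: the bias carries no sharding metadata
    rngs=nnx.Rngs(0),
)
print(linear_1.kernel.sharding)  # ('mlp', 'embed')
```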

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -123,7 +123,7 @@ def __init__(
             "mlp",
           ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        #bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
     self.text_embedder = NNXPixArtAlphaTextProjection(
         rngs=rngs,
@@ -170,6 +170,13 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (
+                "embed",
+                "mlp",
+            ),
+        ),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
```
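This hunk is the inverse of the bias edits above: a previously unannotated linear layer gains an `("embed", "mlp")` kernel annotation. Below is a minimal sketch, with made-up layer sizes, of recovering the resulting `PartitionSpec` from module state via `nnx.get_partition_spec`; the logical names map to mesh axes through the `logical_axis_rules` in base_wan_14b.yml.

```python
# Minimal sketch (made-up sizes): extracting the PartitionSpec produced by
# the new kernel_init annotation.
from flax import nnx

proj_out = nnx.Linear(
    in_features=512,
    out_features=128,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(), ("embed", "mlp")),
    rngs=nnx.Rngs(0),
)

# The kernel entry carries PartitionSpec('embed', 'mlp'); the bias entry has
# no sharding metadata and would be treated as replicated.
print(nnx.get_partition_spec(nnx.state(proj_out)))
```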
