
Commit 76f5bd7

test different sharding strategy
1 parent 043f826 commit 76f5bd7

4 files changed

Lines changed: 39 additions & 37 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 1 deletion
@@ -138,7 +138,14 @@ logical_axis_rules: [
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
-  ['norm', 'tensor'],
+  ["qkv", ["tensor", "fsdp"]],
+  ["proj_out", ["tensor", "fsdp"]],
+  ["timestep_ln1", ["tensor", "fsdp"]],
+  ["timestep_ln2", ["tensor", "fsdp"]],
+  ["text_proj_ln1", ["tensor", "fsdp"]],
+  ["text_proj_ln2", ["tensor", "fsdp"]],
+  ["ffn_lin1", ["tensor", "fsdp"]],
+  ["ffn_lin2", ["tensor", "fsdp"]],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
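
Note: each rule pairs a logical axis name (left) with the mesh axis or axes (right) that dimension is sharded over, so an entry like ["qkv", ["tensor", "fsdp"]] splits a weight axis across the product of the tensor and fsdp mesh dimensions. A minimal sketch of the resolution using Flax's logical-axis-rules helper, assuming a hypothetical four-device 1x2x2 mesh (maxdiffusion's own plumbing may differ):

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

# On CPU, run with XLA_FLAGS=--xla_force_host_platform_device_count=4.
devices = np.array(jax.devices()[:4]).reshape(1, 2, 2)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Rules as in the config: logical name -> mesh axis (or tuple of axes).
rules = (
    ("qkv", ("tensor", "fsdp")),
    ("proj_out", ("tensor", "fsdp")),
    ("embed", "fsdp"),
)

# A QKV kernel annotated with logical axes; None stays unsharded.
logical_spec = PartitionSpec(None, "qkv")
sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
print(sharding.spec)  # PartitionSpec(None, ('tensor', 'fsdp'))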

src/maxdiffusion/models/attention_flax.py

Lines changed: 11 additions & 19 deletions
@@ -734,7 +734,7 @@ def __init__(
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
     # Dims are [num_blocks, embed, heads]
-    kernel_axes = ("embed", "heads")
+    kernel_axes = (None, "qkv")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -747,7 +747,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("qkv",),
         ),
     )
 
@@ -761,7 +761,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("qkv",),
         ),
     )
 
@@ -775,22 +775,19 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("qkv",),
         ),
     )
 
     self.proj_attn = nnx.Linear(
         rngs=rngs,
         in_features=self.inner_dim,
         out_features=self.inner_dim,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.lecun_normal(), ("proj_out", None)),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(
-            nnx.initializers.zeros,
-            ("heads",),
-        ),
     )
 
     self.drop_out = nnx.Dropout(dropout)
@@ -803,21 +800,13 @@ def __init__(
         rngs=rngs,
         epsilon=eps,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(
-            nnx.initializers.ones,
-            ("norm",),
-        ),
         param_dtype=weights_dtype,
     )
 
     self.norm_k = nnx.RMSNorm(
         num_features=self.inner_dim,
         rngs=rngs,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(
-            nnx.initializers.ones,
-            ("norm",),
-        ),
         param_dtype=weights_dtype,
     )
 
@@ -845,8 +834,6 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
@@ -855,6 +842,11 @@ def __call__(
     key_proj = self.key(encoder_hidden_states)
     value_proj = self.value(encoder_hidden_states)
 
+    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+
+
     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
       key_proj = self.norm_k(key_proj)
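
Note: the removed constraints pinned the raw block inputs to a ("data", "fsdp", "tensor") layout; the new strategy instead constrains the Q/K/V projections, splitting the second (sequence) axis over both tensor and fsdp and leaving the feature axis replicated. A minimal runnable sketch of that constraint, with a hypothetical 2x2x1 mesh and toy shapes in place of the real activations:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# On CPU, run with XLA_FLAGS=--xla_force_host_platform_device_count=4.
devices = np.array(jax.devices()[:4]).reshape(2, 2, 1)
mesh = Mesh(devices, axis_names=("data", "tensor", "fsdp"))

@jax.jit
def project(x, w):
    proj = x @ w  # stand-in for self.query / self.key / self.value
    # Batch over "data", sequence over ("tensor", "fsdp"), features replicated.
    spec = PartitionSpec("data", ("tensor", "fsdp"), None)
    return jax.lax.with_sharding_constraint(proj, NamedSharding(mesh, spec))

x = jnp.ones((4, 8, 64))  # [batch, seq, embed]
w = jnp.ones((64, 64))    # [embed, inner_dim]
print(project(x, w).sharding)  # batch split over data; seq split over tensor*fsdp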

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 10 additions & 12 deletions
@@ -95,11 +95,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
-                "mlp",
+                "timestep_ln1",
+                None,
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("timestep_ln1",)),
     )
 
     if cond_proj_dim is not None:
@@ -127,11 +127,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
-                "embed",
+                None,
+                "timestep_ln2",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )
 
     if post_act_fn is None:
@@ -275,11 +274,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
-                "mlp",
+                "text_proj_ln1",
+                None,
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("text_proj_ln1",)),
     )
     self.act_1 = get_activation(act_fn)
 
@@ -294,11 +293,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
-                "embed",
+                None,
+                "text_proj_ln2",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )
 
   def __call__(self, caption):
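
Note: the generic "embed"/"mlp" logical names are replaced with per-layer names (timestep_ln1, timestep_ln2, text_proj_ln1, text_proj_ln2), so the config's logical_axis_rules can target each projection individually; the second linear in each pair also drops its bias annotation entirely. A minimal sketch of how such a name rides on a parameter as sharding metadata; the module name and sizes are illustrative, not the real embedder:

from flax import nnx

class TimestepMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear_1 = nnx.Linear(
            in_features=256,
            out_features=1024,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), ("timestep_ln1", None)
            ),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("timestep_ln1",)),
            rngs=rngs,
        )

model = TimestepMLP(nnx.Rngs(0))
# The logical names surface as PartitionSpecs on the state; the axis rules
# later translate them into mesh axes.
print(nnx.get_partition_spec(nnx.state(model)))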

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 10 additions & 5 deletions
@@ -176,11 +176,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
-                "embed",
+                "ffn_lin1",
+                None,
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("ffn_lin1",)),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -229,8 +229,8 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
-                "mlp",
+                None,
+                "ffn_lin2",
             ),
         ),
     )
@@ -485,6 +485,9 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> Union[jax.Array, Dict[str, jax.Array]]:
+
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data",))
+
     batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
@@ -497,6 +500,8 @@ def __call__(
     hidden_states = self.patch_embedding(hidden_states)
     hidden_states = jax.lax.collapse(hidden_states, 1, -1)
 
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", ("tensor", "fsdp"), None))
+
     temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
         timestep, encoder_hidden_states, encoder_hidden_states_image
     )
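
Note: before patchifying, encoder_hidden_states is pinned to PartitionSpec("data",), i.e. only the batch axis is sharded and every other axis is replicated; the flattened patch embeddings then get the same ("data", ("tensor", "fsdp"), None) layout the attention blocks expect. A minimal sketch of the partial-spec behavior, again with a hypothetical 2x2x1 mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# On CPU, run with XLA_FLAGS=--xla_force_host_platform_device_count=4.
devices = np.array(jax.devices()[:4]).reshape(2, 2, 1)
mesh = Mesh(devices, axis_names=("data", "tensor", "fsdp"))

@jax.jit
def constrain(x):
    # Axes omitted from the spec are treated as replicated, so a rank-3
    # array under PartitionSpec("data") is split on batch only.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec("data")))

x = jnp.ones((4, 512, 4096))  # [batch, seq_len, text_dim], sizes illustrative
print(constrain(x).sharding)  # batch split over "data"; seq and features replicated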
