Commit 2c361e4 (parent: b4f9507)

Fix sharding attempt #1 -> 128 TPU test needed

3 files changed: 10 additions & 3 deletions


src/maxdiffusion/configs/base14.yml (1 addition & 1 deletion)

@@ -58,7 +58,7 @@ mask_padding_tokens: True
 # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 #    in cross attention q.
-attention_sharding_uniform: True
+attention_sharding_uniform: False
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
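Flipping the default to False means self-attention q no longer follows the same sequence-sharding rule as cross-attention q. A minimal sketch of the branch this flag implies, using the logical axis names from base_wan_14b.yml below; the helper itself is hypothetical, not maxdiffusion's actual attention code:

# Hypothetical helper (not maxdiffusion's actual code): the flag selects which
# logical axis names annotate the attention query activation.
def q_logical_axes(attention_sharding_uniform: bool, is_self_attention: bool):
  # Logical axes for a q tensor of shape (batch, q_length, heads, head_dim).
  if attention_sharding_uniform or not is_self_attention:
    # Uniform: q uses the same sequence-sharding rule in self and cross attention.
    return ('activation_batch', 'activation_cross_attn_q_length', 'activation_heads', None)
  # Non-uniform: self-attention q gets its own sequence axis, mapped onto the
  # mesh by the logical_axis_rules added in base_wan_14b.yml below.
  return ('activation_batch', 'activation_self_attn_q_length', 'activation_heads', None)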

src/maxdiffusion/configs/base_wan_14b.yml (3 additions & 2 deletions)

@@ -70,7 +70,7 @@ mask_padding_tokens: True
 # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 #    in cross attention q.
-attention_sharding_uniform: True
+attention_sharding_uniform: False
 dropout: 0.1
 
 flash_block_sizes: {

@@ -165,7 +165,8 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
-  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_self_attn_q_length', 'fsdp'],
+  ['activation_self_attn_kv_length', 'fsdp'],
   ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
   ['activation_heads', 'tensor'],
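The new rules shard the self-attention q/kv sequence length over fsdp only, replacing the old rule that sharded self-attention heads over both fsdp and tensor. A minimal sketch of how such rules resolve to a mesh sharding, using flax's logical-partitioning helper; the 8-device mesh shape is illustrative (the commit message targets a 128-chip TPU test):

# Sketch only: rule names mirror the config, mesh shape is illustrative.
import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()).reshape(1, 4, 2)  # (data, fsdp, tensor); assumes 8 devices
mesh = Mesh(devices, ('data', 'fsdp', 'tensor'))

rules = (
    ('activation_batch', 'data'),
    ('activation_self_attn_q_length', 'fsdp'),   # rule added by this commit
    ('activation_self_attn_kv_length', 'fsdp'),  # rule added by this commit
    ('activation_heads', 'tensor'),
)

# Logical spec for a self-attention q activation: (batch, q_length, heads, head_dim).
logical = PartitionSpec('activation_batch', 'activation_self_attn_q_length', 'activation_heads', None)
sharding = nn.logical_to_mesh_sharding(logical, mesh, rules)
print(sharding.spec)  # expected: batch on 'data', sequence on 'fsdp', heads on 'tensor'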

src/maxdiffusion/models/wan/transformers/transformer_wan.py (6 additions & 0 deletions)

@@ -187,6 +187,8 @@ def __init__(
 
   def __call__(self, x: jax.Array) -> jax.Array:
     x = self.proj(x)
+    jax.debug.print("ApproximateGELU activation shape: {shape}", shape=x.shape)
+    jax.debug.inspect_array_sharding(x, callback=print)
     return nnx.gelu(x)
 

@@ -245,7 +247,11 @@ def conditional_named_scope(self, name: str):
 
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
     with self.conditional_named_scope("mlp_up_proj_and_gelu"):
+      jax.debug.print("MLP input shape: {shape}", shape=hidden_states.shape)
+      jax.debug.inspect_array_sharding(hidden_states, callback=print)
       hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      jax.debug.print("MLP intermediate activation shape: {shape}", shape=hidden_states.shape)
+      jax.debug.inspect_array_sharding(hidden_states, callback=print)
       hidden_states = checkpoint_name(hidden_states, "ffn_activation")
       hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     with self.conditional_named_scope("mlp_down_proj"):
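jax.debug.print and jax.debug.inspect_array_sharding are standard JAX debugging hooks that work inside jitted code. A minimal self-contained sketch of the pattern these additions use; the mesh shape and array sizes here are illustrative, not the 14B model's:

# Sketch of the debug-print + sharding-inspection pattern under jit.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()).reshape(-1), ('fsdp',))
# Assumes the device count divides the leading dimension (8).
x = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, PartitionSpec('fsdp', None)))

@jax.jit
def f(x):
  # Log the activation shape from inside a jitted function.
  jax.debug.print("activation shape: {shape}", shape=x.shape)
  # Report the sharding the compiler chose for x at this point in the program.
  jax.debug.inspect_array_sharding(x, callback=print)
  return jax.nn.gelu(x)

f(x)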
