Commit 3ef352f

set sharding constraints to reduce all-gathers.
1 parent 50d2fe7 commit 3ef352f

4 files changed

Lines changed: 35 additions & 33 deletions

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 from absl import app
 from maxdiffusion.utils import export_to_video

-jax.config.update('jax_use_shardy_partitioner', True)
+jax.config.update("jax_use_shardy_partitioner", True)
+

 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)

src/maxdiffusion/models/attention_flax.py

Lines changed: 31 additions & 31 deletions
@@ -76,8 +76,8 @@ def _reshape_batch_dim_to_heads(tensor, heads):
   head_size = heads
   tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-  tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
-  return tensor
+  reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
+  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


 def _reshape_heads_to_batch_dim(tensor, heads):
@@ -86,12 +86,12 @@ def _reshape_heads_to_batch_dim(tensor, heads):
     head_size = heads
     tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
     tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-    tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+    reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
   else:
     batch_size, head_size, seq_len, head_dim = tensor.shape
-    tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)
+    reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)

-  return tensor
+  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


 def _reshape_heads_to_head_dim(tensor):
@@ -140,14 +140,15 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   # 2. Ensure num_blocks is divisible by num_shards
   num_blocks = seq_len_padded_pre // flash_block_size
   if num_blocks % num_shards != 0:
-    num_blocks += (num_shards - (num_blocks % num_shards))
+    num_blocks += num_shards - (num_blocks % num_shards)

   final_padded_len = num_blocks * flash_block_size
   seq_len_pad = final_padded_len - seq_len

   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    tensor = jnp.pad(tensor, npad)
+    padded_tensor = jnp.pad(tensor, npad)
+    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "fsdp", "tensor"))

   return tensor, kv_size, seq_len
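
The pattern in the three hunks above: every helper that reshapes or pads an activation now re-asserts the ("data", "fsdp", "tensor") layout on its result rather than letting the partitioner infer one, which is what can otherwise insert an all-gather. A stripped-down sketch of the same idea; the function name and shapes are illustrative, and `mesh` is assumed to be the 3-axis mesh from the previous sketch:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

def reshape_with_constraint(tensor, heads, mesh):
  # Illustrative stand-in for the repo's _reshape_batch_dim_to_heads.
  batch_heads, seq_len, dim = tensor.shape
  batch = batch_heads // heads
  out = tensor.reshape(batch, heads, seq_len, dim)
  out = jnp.transpose(out, (0, 2, 1, 3)).reshape(batch, seq_len, dim * heads)
  # Re-assert the layout after the reshape; without this the partitioner may
  # materialize `out` replicated (an all-gather) before the next sharded op.
  return jax.lax.with_sharding_constraint(
      out, NamedSharding(mesh, PartitionSpec("data", "fsdp", "tensor"))
  )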

@@ -189,40 +190,38 @@ def _tpu_flash_attention(
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size=mesh.shape['tensor']
+
+  shard_head_size = mesh.shape["tensor"]

   @functools.partial(
       jax.jit,
-      static_argnames=[
-          "multi_head_mask",
-          "shard_head_size"
-      ],
+      static_argnames=["multi_head_mask", "shard_head_size"],
   )
   def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-      mask=multi_head_mask,
-      head_shards=shard_head_size,  # the size of the axis sharded over heads
-      q_seq_shards=num_fsdp_shards,  # the size of the axis sharded over seq_len
-      block_sizes=block_sizes,
+        mask=multi_head_mask,
+        head_shards=shard_head_size,  # the size of the axis sharded over heads
+        q_seq_shards=num_fsdp_shards,  # the size of the axis sharded over seq_len
+        block_sizes=block_sizes,
     )
     return splash_kernel

   mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
   segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
   @functools.partial(
-    shard_map.shard_map,
-    mesh=mesh,
-    in_specs=(
-      q_axis_names,
-      kv_axis_names,
-      kv_axis_names,
-      segment_axis_names_splash_kernel,
-    ),
-    out_specs=q_axis_names,
-    check_rep=False
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          q_axis_names,
+          kv_axis_names,
+          kv_axis_names,
+          segment_axis_names_splash_kernel,
+      ),
+      out_specs=q_axis_names,
+      check_rep=False,
   )
   def wrap_flash_attention(query, key, value, splash_kernel):
     attention_output = jax.vmap(splash_kernel)(query, key, value)
@@ -386,7 +385,9 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype)
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -566,8 +567,8 @@ class AttentionOp(nn.Module):
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
-  axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
-  axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+  axis_names_q: AxisNames = ((BATCH, HEAD, LENGTH, D_KV),)
+  axis_names_kv: AxisNames = ((BATCH, HEAD, KV_LENGTH, D_KV),)
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
@@ -775,7 +776,6 @@ def __call__(
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))

     attn_output = attn_output.astype(dtype=dtype)
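
One way to verify that constraints like these actually reduce communication is to count collective ops in the optimized HLO before and after the change. A sketch, assuming `fn` stands in for any jitted attention step:

import jax
import jax.numpy as jnp

def count_all_gathers(fn, *args):
  # Compile fn and count all-gather ops in the optimized HLO text.
  hlo = jax.jit(fn).lower(*args).compile().as_text()
  return hlo.count("all-gather")

# Illustrative call with a trivial function and placeholder input.
print(count_all_gathers(lambda x: x * 2, jnp.ones((8, 128))))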

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
@@ -462,6 +462,7 @@ def __call__(

     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
+
     def skip_block_true(hidden_states):
       split_bs = hidden_states.shape[0] // 2
       prev_neg_hidden_states = hidden_states[split_bs:]

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -438,7 +438,7 @@ def __call__(
     latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
     latents = latents / latents_std + latents_mean
     latents = latents.astype(self.config.weights_dtype)
-
+
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       video = self.vae.decode(latents, self.vae_cache)[0]
