Commit 3f6eb05

single forward loop.
1 parent 125dcfa · commit 3f6eb05

4 files changed: 36 additions & 17 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 12 additions & 2 deletions
```diff
@@ -52,7 +52,17 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

-flash_block_sizes: {}
+#flash_block_sizes: {}
+flash_block_sizes: {
+  "block_q" : 2048,
+  "block_kv_compute" : 2048,
+  "block_kv" : 2048,
+  "block_q_dkv" : 2048,
+  "block_kv_dkv" : 2048,
+  "block_kv_dkv_compute" : 2048,
+  "block_q_dq" : 2048,
+  "block_kv_dq" : 2048
+}
 # GroupNorm groups
 norm_num_groups: 32

@@ -112,7 +122,7 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
-  #['activation_heads', 'fsdp'],
+  ['activation_heads', 'tensor'],
   ['activation_length', 'fsdp'],
   #['activation_heads', 'fsdp'],
   #['activation_heads', 'fsdp'],
```
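For reference, the eight keys in this mapping line up with the fields of the splash-attention `BlockSizes` dataclass in JAX's pallas TPU kernels. A minimal sketch of how a dict like this could be mapped onto it; the helper `block_sizes_from_config` is hypothetical, not code from this commit:

```python
# Hypothetical helper: turn the YAML mapping above into a pallas BlockSizes.
# Assumes the eight keys match the splash-attention BlockSizes fields.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def block_sizes_from_config(cfg: dict) -> splash_attention_kernel.BlockSizes:
  # Each entry caps one tile dimension of the forward/backward kernels;
  # 2048 everywhere trades memory for fewer grid steps.
  return splash_attention_kernel.BlockSizes(
      block_q=cfg["block_q"],
      block_kv=cfg["block_kv"],
      block_kv_compute=cfg["block_kv_compute"],
      block_q_dkv=cfg["block_q_dkv"],
      block_kv_dkv=cfg["block_kv_dkv"],
      block_kv_dkv_compute=cfg["block_kv_dkv_compute"],
      block_q_dq=cfg["block_q_dq"],
      block_kv_dq=cfg["block_kv_dq"],
  )
```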

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -162,7 +162,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  max_block_size = 768#1024 if dtype == jnp.bfloat16 else 512
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
@@ -205,8 +205,8 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
   )
   return splash_kernel

-  shard_head_size = 1
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2]))
+  shard_head_size = mesh.shape["tensor"]
+  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
   segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
@@ -223,7 +223,10 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     check_rep=False
   )
   def wrap_flash_attention(query, key, value, splash_kernel):
+    #full_k = jax.lax.all_to_all(key, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True)
+    #full_v = jax.lax.all_to_all(value, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True)
     attention_output = jax.vmap(splash_kernel)(query, key, value)
+    #attention_output = jax.vmap(splash_kernel)(query, full_k, full_v)
     return attention_output

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
```
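Two fixes land here: `max_block_size` returns to its dtype-dependent default instead of the hard-coded 768, and the full mask is built over (q_len, kv_len) rather than (q_len, q_len), which matters whenever the key/value length differs from the query length (cross-attention). The head count used for kernel sharding now follows the mesh's 'tensor' axis, matching the `['activation_heads', 'tensor']` rule added in the config. A standalone sketch of the mask construction, with illustrative shapes:

```python
# Standalone sketch of the mask setup used above; shapes are illustrative.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

num_heads, q_len, kv_len = 16, 4096, 512

# FullMask must cover (q_len, kv_len); using (q_len, q_len) breaks
# cross-attention, where key/value length differs from query length.
mask = splash_attention_mask.FullMask(_shape=(q_len, kv_len))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)
```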

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -469,11 +469,18 @@ def __call__(

     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
+    def skip_block_true(hidden_states):
+      split_bs = hidden_states.shape[0] // 2
+      prev_neg_hidden_states = hidden_states[split_bs:]
+      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+      hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0)
+      return hidden_states
+
     for block_idx, block in enumerate(self.blocks):
       should_skip_block = slg_mask[block_idx] & is_uncond
       hidden_states = jax.lax.cond(
           should_skip_block,
-          lambda hs: hs,  # If true, pass through original hidden_states (skip block)
+          lambda _: skip_block_true(hidden_states),  # if true, run the block on the cond half and keep the uncond half unchanged (skip block for uncond)
           lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb),
           hidden_states,
       )
```
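Both branches of `jax.lax.cond` must produce outputs of identical shape, so the "skip" branch cannot simply drop half the batch: it runs the block on the full batch, then splices the saved unconditional half back over the block's output. A toy sketch of the pattern; `block` below is a stand-in, not the repo's transformer block:

```python
# Toy sketch of batched skip-layer guidance: cond and uncond are stacked in one
# batch, and "skipping" a block applies only to the uncond half.
import jax
import jax.numpy as jnp

def block(h):  # stand-in for a transformer block
  return h * 2.0

def run_block_skip_uncond(h):
  split_bs = h.shape[0] // 2
  prev_neg = h[split_bs:]           # save the uncond half of the input
  h = block(h)                      # run the block on the full batch
  # keep the block's output for the cond half, the saved input for the uncond half
  return jnp.concatenate([h[:split_bs], prev_neg], axis=0)

h = jnp.arange(8.0).reshape(4, 2)   # batch of 4: rows 0-1 cond, rows 2-3 uncond
out = jax.lax.cond(jnp.array(True), run_block_skip_uncond, block, h)
```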

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 10 additions & 11 deletions
```diff
@@ -470,11 +470,17 @@ def run_inference(
     slg_end: float = 1.0,
 ):
   do_classifier_free_guidance = guidance_scale > 1.0
+  if do_classifier_free_guidance:
+    prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
   for step in range(num_inference_steps):
     slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_)
     if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps):
       slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True)
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+    # get original batch size before concat in case of cfg.
+    bsz = latents.shape[0]
+    if do_classifier_free_guidance:
+      latents = jnp.concatenate([latents] * 2)
     timestep = jnp.broadcast_to(t, latents.shape[0])

     noise_pred = transformer_forward_pass(
@@ -484,21 +490,14 @@ def run_inference(
         latents,
         timestep,
         prompt_embeds,
-        is_uncond=jnp.array(False, dtype=jnp.bool_),
+        is_uncond=jnp.array(True, dtype=jnp.bool_),
         slg_mask=slg_mask,
     )

     if do_classifier_free_guidance:
-      noise_uncond = transformer_forward_pass(
-          graphdef,
-          sharded_state,
-          rest_of_state,
-          latents,
-          timestep,
-          negative_prompt_embeds,
-          is_uncond=jnp.array(True, dtype=jnp.bool_),
-          slg_mask=slg_mask,
-      )
+      noise_uncond = noise_pred[bsz:]
+      noise_pred = noise_pred[:bsz]
       noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+      latents = latents[:bsz]
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
   return latents
```
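This is the "single forward loop" of the commit title: the positive and negative prompt embeddings are concatenated once up front, the latents are duplicated each step, and one `transformer_forward_pass` call replaces the former second pass for the negative prompt (with `is_uncond=True` so the skip-layer logic above can act on the unconditional half). A self-contained sketch of the pattern; `fake_model` stands in for the transformer and the shapes are illustrative:

```python
# Self-contained sketch of single-pass classifier-free guidance.
import jax.numpy as jnp

def fake_model(latents, embeds):
  # pretend noise prediction that depends on the text embedding
  return latents + embeds.mean(axis=1, keepdims=True)

guidance_scale = 5.0
latents = jnp.ones((2, 4))                        # original batch, bsz = 2
prompt_embeds = jnp.ones((2, 8))
negative_prompt_embeds = jnp.zeros((2, 8))

bsz = latents.shape[0]                            # remember size before concat
embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
batched = jnp.concatenate([latents] * 2)          # [cond; uncond]

noise = fake_model(batched, embeds)               # one forward pass, not two
noise_uncond, noise_pred = noise[bsz:], noise[:bsz]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
latents = batched[:bsz]                           # drop the duplicated half
```

The trade-off is a doubled batch dimension through the transformer in exchange for a single compiled forward per step.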
