Skip to content

Commit 84d293c

Browse files
single forward pass for pos/neg emb.
1 parent 5494644 commit 84d293c

2 files changed

Lines changed: 21 additions & 14 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,11 +475,19 @@ def __call__(
475475

476476
if encoder_hidden_states_image is not None:
477477
raise NotImplementedError("img2vid is not yet implemented.")
478+
479+
def skip_block_true(hidden_states):
  """SLG (skip-layer-guidance) branch taken when this block is skipped.

  After the single-forward-pass change the batch is the concatenation
  (positive ‖ negative) along axis 0, so "skipping" this transformer block
  means: run it for the positive half, but leave the negative half exactly
  as it came in.

  Args:
    hidden_states: concatenated (cond ‖ uncond) activations; axis 0 is
      assumed to hold an even batch of 2 * original_batch — TODO confirm.

  Returns:
    Activations with the block applied to the first half only.
  """
  half = hidden_states.shape[0] // 2
  # Keep the incoming negative half untouched; the block output for it
  # is computed below but discarded.
  preserved_uncond = hidden_states[half:]
  # NOTE(review): the block still runs over the full batch (including the
  # half we throw away) — presumably required for a uniform lax.cond
  # branch signature; confirm this is intentional.
  processed = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
  return jnp.concatenate([processed[:half], preserved_uncond], axis=0)
485+
478486
for block_idx, block in enumerate(self.blocks):
479487
should_skip_block = slg_mask[block_idx] & is_uncond
480488
hidden_states = jax.lax.cond(
481489
should_skip_block,
482-
lambda hs: hs, # If true, pass through original hidden_states (skip block)
490+
lambda _: skip_block_true(hidden_states), # If true, pass through original hidden_states (skip block)
483491
lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb),
484492
hidden_states,
485493
)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,13 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
199199
wan_vae = nnx.merge(graphdef, params)
200200
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
201201
# Shard
202-
with mesh:
202+
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
203203
wan_vae = p_create_sharded_logical_model(model=wan_vae)
204204
return wan_vae, vae_cache
205205

206206
@classmethod
207207
def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
208-
with mesh:
208+
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
209209
wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
210210
return wan_transformer
211211

@@ -468,11 +468,17 @@ def run_inference(
468468
slg_end: float = 1.0,
469469
):
470470
do_classifier_free_guidance = guidance_scale > 1.0
471+
if do_classifier_free_guidance:
472+
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
471473
for step in range(num_inference_steps):
472474
slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_)
473475
if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps):
474476
slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True)
475477
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
478+
# get original batch size before concat in case of cfg.
479+
bsz = latents.shape[0]
480+
if do_classifier_free_guidance:
481+
latents = jnp.concatenate([latents] * 2)
476482
timestep = jnp.broadcast_to(t, latents.shape[0])
477483

478484
noise_pred = transformer_forward_pass(
@@ -482,21 +488,14 @@ def run_inference(
482488
latents,
483489
timestep,
484490
prompt_embeds,
485-
is_uncond=jnp.array(False, dtype=jnp.bool_),
491+
is_uncond=jnp.array(True, dtype=jnp.bool_),
486492
slg_mask=slg_mask,
487493
)
488494

489495
if do_classifier_free_guidance:
490-
noise_uncond = transformer_forward_pass(
491-
graphdef,
492-
sharded_state,
493-
rest_of_state,
494-
latents,
495-
timestep,
496-
negative_prompt_embeds,
497-
is_uncond=jnp.array(True, dtype=jnp.bool_),
498-
slg_mask=slg_mask,
499-
)
496+
noise_uncond = noise_pred[bsz:]
497+
noise_pred = noise_pred[:bsz]
500498
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
499+
latents = latents[:bsz]
501500
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
502501
return latents

0 commit comments

Comments
 (0)