Skip to content

Commit e7bd680

Browse files
committed
Trying text_mask 6
1 parent d88dd43 commit e7bd680

3 files changed

Lines changed: 48 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,8 @@ def __call__(
100100
negative_prompt_embeds: jax.Array = None,
101101
vae_only: bool = False,
102102
):
103-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
103+
(latents, prompt_embeds, negative_prompt_embeds, text_attention_mask,
104+
negative_text_attention_mask, scheduler_state, num_frames) = self._prepare_model_inputs(
104105
prompt,
105106
negative_prompt,
106107
height,
@@ -139,6 +140,8 @@ def __call__(
139140
latents=latents,
140141
prompt_embeds=prompt_embeds,
141142
negative_prompt_embeds=negative_prompt_embeds,
143+
text_attention_mask=text_attention_mask,
144+
negative_text_attention_mask=negative_text_attention_mask,
142145
)
143146
latents = self._denormalize_latents(latents)
144147
return self._decode_latents_to_video(latents)
@@ -153,6 +156,8 @@ def run_inference_2_2(
153156
latents: jnp.array,
154157
prompt_embeds: jnp.array,
155158
negative_prompt_embeds: jnp.array,
159+
text_attention_mask: Optional[jax.Array],
160+
negative_text_attention_mask: Optional[jax.Array],
156161
guidance_scale_low: float,
157162
guidance_scale_high: float,
158163
boundary: int,
@@ -163,21 +168,27 @@ def run_inference_2_2(
163168
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
164169
if do_classifier_free_guidance:
165170
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
171+
if text_attention_mask is not None and negative_text_attention_mask is not None:
172+
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
173+
else:
174+
encoder_attention_mask = None
166175

167176
def low_noise_branch(operands):
168177
latents, timestep, prompt_embeds = operands
169178
return transformer_forward_pass(
170179
low_noise_graphdef, low_noise_state, low_noise_rest,
171180
latents, timestep, prompt_embeds,
172-
do_classifier_free_guidance, guidance_scale_low
181+
do_classifier_free_guidance, guidance_scale_low,
182+
encoder_attention_mask=encoder_attention_mask,
173183
)
174184

175185
def high_noise_branch(operands):
176186
latents, timestep, prompt_embeds = operands
177187
return transformer_forward_pass(
178188
high_noise_graphdef, high_noise_state, high_noise_rest,
179189
latents, timestep, prompt_embeds,
180-
do_classifier_free_guidance, guidance_scale_high
190+
do_classifier_free_guidance, guidance_scale_high,
191+
encoder_attention_mask=encoder_attention_mask,
181192
)
182193

183194
for step in range(num_inference_steps):

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,8 @@ def __call__(
171171
max_logging.log(f"Adjusted num_frames to: {num_frames}")
172172
num_frames = max(num_frames, 1)
173173

174-
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
174+
(prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask,
175+
image_embeds, effective_batch_size) = self._prepare_model_inputs_i2v(
175176
prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length,
176177
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
177178
)
@@ -230,9 +231,11 @@ def __call__(
230231
latents = p_run_inference(
231232
latents=latents,
232233
condition=condition,
233-
prompt_embeds=prompt_embeds,
234-
negative_prompt_embeds=negative_prompt_embeds,
235-
image_embeds=image_embeds,
234+
prompt_embeds=prompt_embeds,
235+
negative_prompt_embeds=negative_prompt_embeds,
236+
text_attention_mask=text_attention_mask,
237+
negative_text_attention_mask=negative_text_attention_mask,
238+
image_embeds=image_embeds,
236239
first_frame_mask=first_frame_mask,
237240
scheduler_state=scheduler_state,
238241
rng=inference_rng,
@@ -251,6 +254,8 @@ def run_inference_2_1_i2v(
251254
condition: jnp.array,
252255
prompt_embeds: jnp.array,
253256
negative_prompt_embeds: jnp.array,
257+
text_attention_mask: Optional[jax.Array],
258+
negative_text_attention_mask: Optional[jax.Array],
254259
image_embeds: jnp.array,
255260
guidance_scale: float,
256261
num_inference_steps: int,
@@ -263,6 +268,10 @@ def run_inference_2_1_i2v(
263268

264269
if do_classifier_free_guidance:
265270
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
271+
if text_attention_mask is not None and negative_text_attention_mask is not None:
272+
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
273+
else:
274+
encoder_attention_mask = None
266275
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
267276
condition = jnp.concatenate([condition] * 2)
268277
for step in range(num_inference_steps):
@@ -280,6 +289,7 @@ def run_inference_2_1_i2v(
280289
do_classifier_free_guidance=do_classifier_free_guidance,
281290
guidance_scale=guidance_scale,
282291
encoder_hidden_states_image=image_embeds,
292+
encoder_attention_mask=encoder_attention_mask,
283293
)
284294
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
285295
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False)

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,8 @@ def __call__(
165165
max_logging.log(f"Adjusted num_frames to: {num_frames}")
166166
num_frames = max(num_frames, 1)
167167

168-
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
168+
(prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask,
169+
image_embeds, effective_batch_size) = self._prepare_model_inputs_i2v(
169170
prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length,
170171
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
171172
)
@@ -224,6 +225,8 @@ def __call__(
224225
boundary=boundary_timestep,
225226
num_inference_steps=num_inference_steps,
226227
scheduler=self.scheduler,
228+
text_attention_mask=text_attention_mask,
229+
negative_text_attention_mask=negative_text_attention_mask,
227230
image_embeds=image_embeds,
228231
first_frame_mask=first_frame_mask,
229232
)
@@ -250,6 +253,8 @@ def run_inference_2_2_i2v(
250253
condition: jnp.array,
251254
prompt_embeds: jnp.array,
252255
negative_prompt_embeds: jnp.array,
256+
text_attention_mask: Optional[jax.Array],
257+
negative_text_attention_mask: Optional[jax.Array],
253258
image_embeds: jnp.array,
254259
first_frame_mask: Optional[jnp.array],
255260
guidance_scale_low: float,
@@ -261,14 +266,25 @@ def run_inference_2_2_i2v(
261266
rng: jax.Array,
262267
):
263268
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
269+
270+
if do_classifier_free_guidance:
271+
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
272+
if text_attention_mask is not None and negative_text_attention_mask is not None:
273+
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
274+
else:
275+
encoder_attention_mask = None
276+
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
277+
condition = jnp.concatenate([condition] * 2)
278+
264279
def high_noise_branch(operands):
265280
latents_input, ts_input, pe_input, ie_input = operands
266281
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
267282
noise_pred, latents_out = transformer_forward_pass(
268283
high_noise_graphdef, high_noise_state, high_noise_rest,
269284
latents_input, ts_input, pe_input,
270285
do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_high,
271-
encoder_hidden_states_image=ie_input
286+
encoder_hidden_states_image=ie_input,
287+
encoder_attention_mask=encoder_attention_mask,
272288
)
273289
return noise_pred, latents_out
274290

@@ -279,7 +295,8 @@ def low_noise_branch(operands):
279295
low_noise_graphdef, low_noise_state, low_noise_rest,
280296
latents_input, ts_input, pe_input,
281297
do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_low,
282-
encoder_hidden_states_image=ie_input
298+
encoder_hidden_states_image=ie_input,
299+
encoder_attention_mask=encoder_attention_mask,
283300
)
284301
return noise_pred, latents_out
285302

0 commit comments

Comments (0)