Skip to content

Commit 1b67d71

Browse files
committed
scan issue fixed in wan 2.1
1 parent 9ee2901 commit 1b67d71

1 file changed

Lines changed: 21 additions & 12 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from maxdiffusion import max_logging
1516
from maxdiffusion.image_processor import PipelineImageInput
1617
from .wan_pipeline import WanPipeline, transformer_forward_pass
1718
from ...models.wan.transformers.transformer_wan import WanModel
@@ -197,33 +198,41 @@ def __call__(
197198
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
198199
)
199200

201+
if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
202+
max_logging.log("[DEBUG] Priming scheduler state...")
203+
t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
204+
dummy_noise = jnp.zeros_like(latents)
205+
# This call initializes the internal state arrays
206+
_, scheduler_state = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
207+
max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}")
208+
200209
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
210+
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
211+
latents = jax.device_put(latents, data_sharding)
212+
condition = jax.device_put(condition, data_sharding)
213+
if first_frame_mask is not None:
214+
first_frame_mask = jax.device_put(first_frame_mask, data_sharding)
201215

202216
p_run_inference = partial(
203217
run_inference_2_1_i2v,
218+
graphdef=graphdef,
219+
sharded_state=state,
220+
rest_of_state=rest_of_state,
204221
guidance_scale=guidance_scale,
205222
num_inference_steps=num_inference_steps,
206223
scheduler=self.scheduler,
207-
image_embeds=image_embeds,
208-
expand_timesteps=self.config.expand_timesteps,
209-
first_frame_mask=first_frame_mask,
224+
expand_timesteps=self.config.expand_timesteps
210225
)
211226

212-
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
213-
latents = jax.device_put(latents, data_sharding)
214-
condition = jax.device_put(condition, data_sharding)
215-
if first_frame_mask is not None:
216-
first_frame_mask = jax.device_put(first_frame_mask, data_sharding)
217-
227+
218228
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
219229
latents = p_run_inference(
220-
graphdef=graphdef,
221-
sharded_state=state,
222-
rest_of_state=rest_of_state,
223230
latents=latents,
224231
condition=condition,
225232
prompt_embeds=prompt_embeds,
226233
negative_prompt_embeds=negative_prompt_embeds,
234+
image_embeds=image_embeds,
235+
first_frame_mask=first_frame_mask,
227236
scheduler_state=scheduler_state,
228237
rng=inference_rng,
229238
)

0 commit comments

Comments
 (0)