|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from maxdiffusion import max_logging |
15 | 16 | from maxdiffusion.image_processor import PipelineImageInput |
16 | 17 | from .wan_pipeline import WanPipeline, transformer_forward_pass |
17 | 18 | from ...models.wan.transformers.transformer_wan import WanModel |
@@ -197,33 +198,41 @@ def __call__( |
197 | 198 | self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape |
198 | 199 | ) |
199 | 200 |
|
| 201 | + if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None: |
| 202 | + max_logging.log("[DEBUG] Priming scheduler state...") |
| 203 | + t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0] |
| 204 | + dummy_noise = jnp.zeros_like(latents) |
| 205 | + # This call initializes the internal state arrays |
| 206 | + _, scheduler_state = self.scheduler.step(scheduler_state, dummy_noise, t0, latents) |
| 207 | + max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}") |
| 208 | + |
200 | 209 | graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) |
| 210 | + data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding)) |
| 211 | + latents = jax.device_put(latents, data_sharding) |
| 212 | + condition = jax.device_put(condition, data_sharding) |
| 213 | + if first_frame_mask is not None: |
| 214 | + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) |
201 | 215 |
|
202 | 216 | p_run_inference = partial( |
203 | 217 | run_inference_2_1_i2v, |
| 218 | + graphdef=graphdef, |
| 219 | + sharded_state=state, |
| 220 | + rest_of_state=rest_of_state, |
204 | 221 | guidance_scale=guidance_scale, |
205 | 222 | num_inference_steps=num_inference_steps, |
206 | 223 | scheduler=self.scheduler, |
207 | | - image_embeds=image_embeds, |
208 | | - expand_timesteps=self.config.expand_timesteps, |
209 | | - first_frame_mask=first_frame_mask, |
| 224 | + expand_timesteps=self.config.expand_timesteps |
210 | 225 | ) |
211 | 226 |
|
212 | | - data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding)) |
213 | | - latents = jax.device_put(latents, data_sharding) |
214 | | - condition = jax.device_put(condition, data_sharding) |
215 | | - if first_frame_mask is not None: |
216 | | - first_frame_mask = jax.device_put(first_frame_mask, data_sharding) |
217 | | - |
| 227 | + |
218 | 228 | with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): |
219 | 229 | latents = p_run_inference( |
220 | | - graphdef=graphdef, |
221 | | - sharded_state=state, |
222 | | - rest_of_state=rest_of_state, |
223 | 230 | latents=latents, |
224 | 231 | condition=condition, |
225 | 232 | prompt_embeds=prompt_embeds, |
226 | 233 | negative_prompt_embeds=negative_prompt_embeds, |
| 234 | + image_embeds=image_embeds, |
| 235 | + first_frame_mask=first_frame_mask, |
227 | 236 | scheduler_state=scheduler_state, |
228 | 237 | rng=inference_rng, |
229 | 238 | ) |
|
0 commit comments