Skip to content

Commit 161ea22

Browse files
committed
Added randn_tensors
1 parent 6c43354 commit 161ea22

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def run_inference_2_1_i2v(
218218
num_inference_steps: int,
219219
scheduler: FlaxUniPCMultistepScheduler,
220220
scheduler_state,
221-
rng: jax.random.KeyArray,
221+
rng: jax.Array,
222222
expand_timesteps: bool,
223223
first_frame_mask: Optional[jnp.array],
224224
):

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __call__(
143143
image_embeds: Optional[jax.Array] = None,
144144
last_image: Optional[PipelineImageInput] = None,
145145
output_type: Optional[str] = "np",
146-
rng: Optional[jax.random.KeyArray] = None,
146+
rng: Optional[jax.Array] = None,
147147
):
148148
height = height or self.config.height
149149
width = width or self.config.width
@@ -236,7 +236,7 @@ def run_inference_2_2_i2v(
236236
num_inference_steps: int,
237237
scheduler: FlaxUniPCMultistepScheduler,
238238
scheduler_state,
239-
rng: jax.random.KeyArray,
239+
rng: jax.Array,
240240
expand_timesteps: bool,
241241
):
242242
do_classifier_free_guidance = guidance_scale > 1.0

0 commit comments

Comments
 (0)