Skip to content

Commit 6c43354

Browse files
committed
Added randn_tensors
1 parent 41cb3d0 commit 6c43354

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def prepare_latents(
8383
width: int,
8484
num_frames: int,
8585
dtype: jnp.dtype,
86-
rng: jax.random.KeyArray,
86+
rng: jax.Array,
8787
latents: Optional[jax.Array] = None,
8888
last_image: Optional[jax.Array] = None,
8989
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
@@ -133,7 +133,7 @@ def __call__(
133133
image_embeds: Optional[jax.Array] = None,
134134
last_image: Optional[PipelineImageInput] = None,
135135
output_type: Optional[str] = "np",
136-
rng: Optional[jax.random.KeyArray] = None,
136+
rng: Optional[jax.Array] = None,
137137
):
138138
height = height or self.config.height
139139
width = width or self.config.width

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def prepare_latents(
8282
width: int,
8383
num_frames: int,
8484
dtype: jnp.dtype,
85-
rng: jax.random.KeyArray,
85+
rng: jax.Array,
8686
latents: Optional[jax.Array] = None,
8787
last_image: Optional[jax.Array] = None,
8888
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:

0 commit comments

Comments
 (0)