Skip to content

Commit d8150bb

Browse files
committed
seed added to config
1 parent ff0d222 commit d8150bb

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,4 @@ bwd_quantization_calibration_method: "absmax"
9090
qwix_module_path: ".*"
9191
jit_initializers: True
9292
enable_single_replica_ckpt_restoring: False
93+
seed: 0

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,8 @@ def prepare_latents(
10101010

10111011
shape = (batch_size, num_channels_latents, num_frames, height, width)
10121012
if generator is None:
1013-
generator = jax.random.key(0)
1013+
seed = getattr(self.config, "seed", 1) if hasattr(self, "config") else 1
1014+
generator = jax.random.key(seed)
10141015

10151016
latents = jax.random.normal(generator, shape, dtype=dtype or jnp.float32)
10161017
latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)

0 commit comments

Comments
 (0)