Skip to content

Commit c29fdc4

Browse files
committed
Disable unsafe rng
1 parent d848983 commit c29fdc4

1 file changed

Lines changed: 0 additions & 9 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,6 @@ def get_git_commit_hash():
7676
return None
7777

7878
jax.config.update("jax_use_shardy_partitioner", True)
79-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
80-
# TF allocates extraneous GPU memory when using TFDS data
81-
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
82-
# tf.config.set_visible_devices([], "GPU")
83-
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
84-
max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.")
85-
os.environ["LIBTPU_INIT_ARGS"] = (
86-
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
87-
)
8879

8980
def call_pipeline(config, pipeline, prompt, negative_prompt):
9081
model_key = config.model_name

0 commit comments

Comments
 (0)