Commit 5d9d758

chore: Add tokamax dependency and unsafe RNG utilities

1 parent: f1ff3cc

5 files changed: 31 additions & 12 deletions


.github/workflows/UnitTests.yml

Lines changed: 1 addition & 2 deletions
@@ -58,8 +58,7 @@ jobs:
         pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
     - name: PyTest
       run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-  # add_pull_ready:
+        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
   # if: github.ref != 'refs/heads/main'
   # permissions:
   #   checks: read
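Note (not part of the commit): LIBTPU_INIT_ARGS is a space-separated list of libtpu/XLA flags read when the TPU backend initializes; the new entry raises the scoped VMEM limit (in KiB) so kernels with larger block sizes can fit. A minimal sketch of appending such a flag from Python before JAX starts, mirroring the pattern generate_wan.py adopts below (append_libtpu_flag is a hypothetical helper):

import os

def append_libtpu_flag(flag: str) -> None:
    # Hypothetical helper: add a libtpu/XLA flag unless a flag with the
    # same name is already present in LIBTPU_INIT_ARGS.
    existing = os.environ.get("LIBTPU_INIT_ARGS", "")
    if flag.split("=")[0] not in existing:
        os.environ["LIBTPU_INIT_ARGS"] = (existing + " " + flag).strip()

# Must run before JAX creates the TPU backend (i.e. before the first JAX call).
append_libtpu_flag("--xla_tpu_scoped_vmem_limit_kib=65472")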

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ celerybeat-schedule
 *.sage.py
 
 # Environments
+.history
 .env
 .venv
 env/

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ ftfy
 tensorboard>=2.17.0
 tensorboardx>=2.6.2.2
 tensorboard-plugin-profile>=2.15.2
+tokamax
 Jinja2
 scikit-image
 parameterized

src/maxdiffusion/generate_wan.py

Lines changed: 9 additions & 0 deletions
@@ -62,6 +62,15 @@ def delete_file(file_path: str):
 
 
 jax.config.update("jax_use_shardy_partitioner", True)
+jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+# TF allocates extraneous GPU memory when using TFDS data,
+# which leads to CUDA OOMs. The workaround for now is to hide GPUs from TF.
+# tf.config.set_visible_devices([], "GPU")
+if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
+  max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.")
+  os.environ["LIBTPU_INIT_ARGS"] = (
+      os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
+  )
 
 
 def get_pipeline(model_name: str):
   if model_name == "wan2.1":
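Note (explanatory, not part of the commit): "unsafe_rbg" keeps the usual jax.random key API but replaces the default threefry2x32 key derivation with a faster RBG-based generator, which the xla_tpu_spmd_rng_bit_generator_unsafe flag lets XLA partition under SPMD sharding rather than replicate, at the cost of strict bit-exact reproducibility. Downstream sampling code is unaffected; a minimal sketch:

import jax

# Must run before any keys are created, as generate_wan.py does at import time.
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

key = jax.random.PRNGKey(0)          # same key API as the default implementation
key, subkey = jax.random.split(key)  # splitting works unchanged
noise = jax.random.normal(subkey, (4, 4))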

src/maxdiffusion/max_utils.py

Lines changed: 19 additions & 10 deletions
@@ -501,17 +501,26 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
-    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
+    attention_is_tokamax = "tokamax" in config.attention
+    user_block_sizes: Dict[str, int] = config.flash_block_sizes
+    if attention_is_tokamax:
+      max_logging.log("Tokamax kernel specified. Note: Tokamax only supports the fused backward kernel, "
+                      "so the following flash block properties will be ignored: "
+                      f"block_q: {user_block_sizes['block_q']}, "
+                      f"block_q_dq: {user_block_sizes.get('block_q_dq')}, "
+                      f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}, "
+                      f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
+      )
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
-        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
+        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"],
+        block_kv_compute=user_block_sizes["block_kv_compute"],
+        block_kv=user_block_sizes["block_kv"],
+        block_q_dkv=user_block_sizes["block_q_dkv"],
+        block_kv_dkv=user_block_sizes["block_kv_dkv"],
+        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
+        block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
+        block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
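To make the new branch behavior concrete, a minimal, self-contained sketch (not part of the commit; the config values are hypothetical and a plain dict stands in for splash_attention_kernel.BlockSizes): when a tokamax kernel is selected, the backward pass is always fused, the dq block settings are dropped, and block_q is derived from the dkv/kv blocks.

from types import SimpleNamespace

def resolve_bwd_kwargs(config):
    # Mirrors the tokamax branch of get_flash_block_sizes above.
    sizes = config.flash_block_sizes
    is_tokamax = "tokamax" in config.attention
    return {
        "block_q": sizes.get("block_q_dkv", sizes["block_kv"]) if is_tokamax else sizes["block_q"],
        "block_q_dq": None if is_tokamax else sizes.get("block_q_dq"),
        "block_kv_dq": None if is_tokamax else sizes.get("block_kv_dq"),
        "use_fused_bwd_kernel": True if is_tokamax else sizes.get("use_fused_bwd_kernel"),
    }

cfg = SimpleNamespace(
    attention="tokamax_flash",  # hypothetical attention kind containing "tokamax"
    flash_block_sizes={"block_q": 512, "block_kv": 256, "block_q_dkv": 256},
)
print(resolve_bwd_kwargs(cfg))
# -> {'block_q': 256, 'block_q_dq': None, 'block_kv_dq': None, 'use_fused_bwd_kernel': True}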
