Skip to content

Commit 7bfc64a

Browse files
committed
tokamax ring attention
1 parent 384d211 commit 7bfc64a

43 files changed

Lines changed: 7964 additions & 120 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ jobs:
5757
- name: PyTest
5858
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
5959
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
60-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
61-
# add_pull_ready:
60+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
61+
# add_pull_ready:q
6262
# if: github.ref != 'refs/heads/main'
6363
# permissions:
6464
# checks: read
6565
# pull-requests: write
6666
# needs: build
67-
# uses: ./.github/workflows/AddLabel.yml
67+
# uses: ./.github/workflows/AddLabel.yml

README.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,24 @@ To generate images, run the following command:
597597
...
598598
```
599599

600+
### Ring Attention
601+
We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP):
602+
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
603+
| -- | -- | -- | -- | -- | -- |
604+
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
605+
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
606+
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
607+
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |
608+
609+
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
610+
| -- | -- | -- | -- | -- | -- |
611+
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
612+
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
613+
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
614+
| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |
615+
616+
(* There are some known stability issues for ring attention on 16 TPUs, please use `tokamax_flash` attention instead.)
617+
600618
## Flux
601619

602620
First make sure you have permissions to access the Flux repos in Huggingface.
@@ -772,4 +790,3 @@ This script will automatically format your code with `pyink` and help you identi
772790
773791
774792
The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.
775-

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
33

44
absl-py>=2.3.1
5-
accelerate>=1.13.0
65
aiofiles>=25.1.0
76
aiohappyeyeballs>=2.6.1
87
aiohttp>=3.13.3
@@ -81,7 +80,6 @@ isort>=8.0.1
8180
jaraco-functools>=4.4.0
8281
jax>=0.9.0
8382
jaxlib>=0.9.0
84-
jaxopt>=0.8.5
8583
jaxtyping>=0.3.9
8684
jinja2>=3.1.6
8785
keras>=3.13.1

docker_build_dependency_image.sh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ if [[ ${DEVICE} == "gpu" ]]; then
6666
export BASEIMAGE=ghcr.io/nvidia/jax:base
6767
fi
6868
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
69+
<<<<<<< HEAD
70+
else
71+
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
72+
if [[ ! -v BASEIMAGE ]]; then
73+
echo "Erroring out because BASEIMAGE is unset, please set it!"
74+
exit 1
75+
fi
76+
docker build --no-cache \
77+
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
78+
--build-arg COMMIT_HASH=${COMMIT_HASH} \
79+
--network=host \
80+
-t ${LOCAL_IMAGE_NAME} \
81+
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
82+
else
83+
docker build --no-cache \
84+
--network=host \
85+
--build-arg MODE=${MODE} \
86+
--build-arg JAX_VERSION=${JAX_VERSION} \
87+
-t ${LOCAL_IMAGE_NAME} \
88+
-f maxdiffusion_dependencies.Dockerfile .
89+
fi
90+
=======
6991
else
7092
# Default to maxdiffusion_dependencies.Dockerfile for non-GPU builds
7193
export BASEIMAGE=${BASEIMAGE:-python:3.12-slim-bullseye}
@@ -76,4 +98,5 @@ else
7698
--build-arg BASEIMAGE=${BASEIMAGE} \
7799
-t ${LOCAL_IMAGE_NAME} \
78100
-f maxdiffusion_dependencies.Dockerfile .
101+
>>>>>>> origin/main
79102
fi

maxdiffusion_dependencies.Dockerfile

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
<<<<<<< HEAD
2+
# Use Python 3.12-slim-bullseye as the base image
3+
FROM python:3.12-slim-bullseye
4+
=======
15
# Use Python 3.12-slim-bullseye as the base image unless overridden
26
ARG BASEIMAGE=python:3.12-slim-bullseye
37
FROM $BASEIMAGE
8+
>>>>>>> origin/main
49

510
# Environment variable for no-cache-dir and pip root user warning
611
ENV PIP_NO_CACHE_DIR=1
@@ -13,8 +18,13 @@ ENV CLOUD_SDK_VERSION=latest
1318
# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors
1419
ENV DEBIAN_FRONTEND=noninteractive
1520

21+
<<<<<<< HEAD
22+
# Upgrade pip to the latest version
23+
RUN python -m pip install --upgrade pip --no-warn-script-location
24+
=======
1625
# Upgrade pip to the latest version and install uv
1726
RUN python -m pip install --upgrade pip uv --no-warn-script-location
27+
>>>>>>> origin/main
1828

1929
# Install system dependencies
2030
RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/*
@@ -26,12 +36,26 @@ RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dea
2636
# Install the Google Cloud SDK
2737
RUN apt-get update && apt-get install -y google-cloud-sdk && rm -rf /var/lib/apt/lists/*
2838

39+
<<<<<<< HEAD
40+
# Install cloud-accelerator-diagnostics
41+
RUN pip install cloud-accelerator-diagnostics
42+
43+
# Install cloud-tpu-diagnostics
44+
RUN pip install cloud-tpu-diagnostics
45+
46+
# Install gcsfs
47+
RUN pip install gcsfs
48+
49+
# Install google-cloud-storage
50+
RUN pip install google-cloud-storage
51+
=======
2952
# Install diagnostic and storage dependencies using uv
3053
RUN python -m uv pip install --system \
3154
cloud-accelerator-diagnostics \
3255
cloud-tpu-diagnostics \
3356
gcsfs \
3457
google-cloud-storage
58+
>>>>>>> origin/main
3559

3660
# Args
3761
ARG MODE
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
ARG JAX_AI_IMAGE_BASEIMAGE
2+
3+
# JAX AI Base Image
4+
FROM $JAX_AI_IMAGE_BASEIMAGE
5+
6+
ARG JAX_AI_IMAGE_BASEIMAGE
7+
8+
ARG COMMIT_HASH
9+
10+
ENV COMMIT_HASH=$COMMIT_HASH
11+
12+
RUN mkdir -p /deps
13+
14+
# Set the working directory in the container
15+
WORKDIR /deps
16+
17+
# Copy all files from local workspace into docker container
18+
COPY . .
19+
20+
# Install Maxdiffusion Jax AI Image requirements
21+
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
22+
23+
# TODO: Remove the flax pin and fsspec overrides once flax stable version releases
24+
RUN if echo "$JAX_AI_IMAGE_BASEIMAGE" | grep -q "nightly"; then \
25+
echo "Nightly build detected: Installing specific Flax commit and fsspec." && \
26+
pip install --upgrade --force-reinstall git+https://github.com/google/flax.git@ef78d6584623511746be4824965cdef42b464583 && \
27+
pip install "fsspec==2025.10.0"; \
28+
fi
29+
30+
# Run the script available in JAX-AI-Image base image to generate the manifest file
31+
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

requirements.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
jax>=0.7.2
3+
jaxlib>=0.4.30
4+
grain
5+
google-cloud-storage>=2.17.0
6+
absl-py
7+
chex
8+
datasets
9+
flax>=0.12.0
10+
optax>=0.2.3
11+
torch>=2.6.0
12+
torchvision>=0.20.1
13+
ftfy
14+
tensorboard>=2.17.0
15+
tensorboardx>=2.6.2.2
16+
tensorboard-plugin-profile>=2.15.2
17+
Jinja2
18+
scikit-image
19+
parameterized
20+
Pillow
21+
pylint
22+
pyink
23+
pytest==8.2.2
24+
tensorflow>=2.17.0
25+
tensorflow-datasets>=4.9.6
26+
ruff>=0.1.5,<=0.2
27+
git+https://github.com/Lightricks/LTX-Video
28+
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
29+
opencv-python-headless==4.10.0.84
30+
orbax-checkpoint
31+
tokenizers==0.21.0
32+
huggingface_hub>=0.30.2
33+
transformers==4.51.0
34+
einops==0.8.0
35+
sentencepiece
36+
aqtp
37+
imageio==2.37.0
38+
imageio-ffmpeg==0.6.0
39+
hf_transfer>=0.1.9
40+
qwix@git+https://github.com/google/qwix.git

requirements_with_jax_ai_image.txt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Requirements for Building the MaxDifussion Docker Image
2+
# These requirements are additional to the dependencies present in the JAX AI base image.
3+
--extra-index-url https://download.pytorch.org/whl/cpu
4+
jax>=0.7.2
5+
jaxlib>=0.4.30
6+
grain
7+
google-cloud-storage>=2.17.0
8+
absl-py
9+
chex
10+
datasets
11+
flax>=0.12.0
12+
optax>=0.2.3
13+
torch>=2.6.0
14+
torchvision>=0.20.1
15+
ftfy
16+
tensorboard>=2.17.0
17+
tensorboardx>=2.6.2.2
18+
tensorboard-plugin-profile>=2.15.2
19+
Jinja2
20+
scikit-image
21+
parameterized
22+
Pillow
23+
pylint
24+
pyink
25+
pytest==8.2.2
26+
tensorflow>=2.17.0
27+
tensorflow-datasets>=4.9.6
28+
ruff>=0.1.5,<=0.2
29+
opencv-python-headless==4.10.0.84
30+
orbax-checkpoint
31+
tokenizers==0.21.0
32+
huggingface_hub>=0.30.2
33+
transformers==4.51.0
34+
tokamax
35+
einops==0.8.0
36+
sentencepiece
37+
aqtp
38+
imageio==2.37.0
39+
imageio-ffmpeg==0.6.0
40+
hf_transfer>=0.1.9
41+
qwix@git+https://github.com/google/qwix.git

setup.cfg

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[isort]
2+
default_section = FIRSTPARTY
3+
ensure_newline_before_comments = True
4+
force_grid_wrap = 0
5+
include_trailing_comma = True
6+
known_first_party = accelerate
7+
known_third_party =
8+
numpy
9+
torch
10+
torch_xla
11+
12+
line_length = 119
13+
lines_after_imports = 2
14+
multi_line_output = 3
15+
use_parentheses = True
16+
17+
[flake8]
18+
ignore = E203, E722, E501, E741, W503, W605
19+
max-line-length = 119
20+
per-file-ignores = __init__.py:F401

0 commit comments

Comments
 (0)