
Commit 8e364d9

Merge branch 'main' into cross_self_attention_switch
2 parents 5182222 + d843dc0 commit 8e364d9

34 files changed

Lines changed: 2534 additions & 335 deletions

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,7 +58,7 @@ jobs:
         pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
     - name: PyTest
       run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
+        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
 # add_pull_ready:
 #   if: github.ref != 'refs/heads/main'
 #   permissions:
```

README.md

Lines changed: 374 additions & 69 deletions
Large diffs are not rendered by default.

docs/dgx_spark.md

Lines changed: 196 additions & 0 deletions
# MaxDiffusion on NVIDIA DGX Spark GPU: A Complete User Guide

This guide provides a detailed step-by-step walkthrough for setting up and running the maxdiffusion library within a custom Docker environment on an ARM-based machine with NVIDIA GPU support. We will cover everything from building the optimized Docker image to generating your first image and retrieving it successfully.

## Prerequisites

Before you begin, ensure you have the following:

- Access to an [NVIDIA DGX Spark box](https://www.nvidia.com/en-us/products/workstations/dgx-spark/).
- The maxdiffusion source code cloned onto the machine.
  - Branch: `dgx_spark`
- An internet connection for the initial Docker build and for downloading models (if not cached).

## Part 1: Building the Optimized Docker Image

The foundation of a smooth workflow is a well-built Docker image. The following Dockerfile is optimized for build speed by caching dependencies, ensuring that code changes don't require a full reinstall of all libraries.

### Step 1: Create the Dockerfile

In the root directory of your maxdiffusion project, create a file named `box.Dockerfile` and paste the following content into it.

```docker
# NVIDIA base image for ARM64 with CUDA support.
# Used in place of the JAX AI Image, which currently doesn't support ARM builds.
FROM nvcr.io/nvidia/cuda-dl-base@sha256:3631d968c12ef22b1dfe604de63dbc71a55f3ffcc23a085677a6d539d98884a4

# Set environment variables (these rarely change)
ENV PIP_BREAK_SYSTEM_PACKAGES=1
ENV DEBIAN_FRONTEND=noninteractive

# Install system-level dependencies (these change very infrequently)
RUN apt-get update && apt-get install -y python3 python3-pip
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \
    update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1

WORKDIR /app

# --- Dependency Installation Layer ---
# First, copy only the requirements file to leverage caching
COPY requirements.txt .

# Install dependencies from requirements.txt
RUN pip install -r requirements.txt

# Install other major Python libraries in separate layers for better caching
RUN pip install "jax[cuda13-local]==0.7.2"

# --- Application Code Layer ---
# Now, copy your application source code. This layer is rebuilt only when your code changes.
COPY . .

# Install the maxdiffusion package from the copied source
RUN pip install .

# Set a default command to keep the container running for interactive use
CMD ["/bin/bash"]
```

### Step 2: Build the Image

Open a terminal on the DGX Spark, navigate to the root directory of the maxdiffusion project, and run the build command:

```bash
docker build -f box.Dockerfile -t maxdiffusion-arm-gpu .
```

This command executes the steps in the Dockerfile, downloads the necessary layers, installs all dependencies, and creates a local Docker image named `maxdiffusion-arm-gpu`. The first build may take some time; subsequent builds will be much faster if you only change the source code.
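Build speed also depends on the size of the Docker build context that `COPY . .` pulls in. A `.dockerignore` file in the project root keeps large, irrelevant files out of the context; the entries below are illustrative for a typical checkout, not part of the maxdiffusion repo:

```
# .dockerignore (illustrative) — keep these out of the build context
.git
__pycache__/
*.pyc
.venv/
*.png
```

With this in place, editing an ignored file no longer invalidates the `COPY . .` layer, so the cached dependency layers are reused.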
## Part 2: Running the Container for Image Generation

To run the image generator effectively, we need to connect our local machine's folders to the container. This prevents re-downloading models and makes it easy to retrieve the output images.

### Step 1: Create a Local Output Directory

On your DGX Spark, create a directory to store the generated images.

```bash
mkdir -p ~/maxdiffusion_output
```

### Step 2a: Launch the Container with Volume Mounts

Run the following command to start an interactive session inside your container. This command links your Hugging Face cache (to avoid re-downloading models) and the output directory you just created.

```bash
docker run -it --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -v ~/maxdiffusion_output:/tmp \
  maxdiffusion-arm-gpu
```

Your terminal prompt will change, indicating you are now inside the running container.
### Step 2b: Log in to Hugging Face (First-Time Setup)

You must do this once to download the required model weights.

```bash
# [Inside the Docker container]
huggingface-cli login
```

You will be prompted to paste a Hugging Face User Access Token.

1. Go to [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) in your web browser.
2. Copy your token (or create a new one; a token with read access is sufficient for downloading models).
3. Paste the token into the terminal and press Enter.

## Part 3: Generating Your First Image

Now that you are inside the container's interactive shell, you can execute the image generation script. Run the following command:

```bash
NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
```

The script will initialize, use the models from your mounted cache, and begin the generation process.
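The trailing `key=value` arguments override settings loaded from the YAML config file. A minimal sketch of that override pattern is below; this illustrates the idea only and is not maxdiffusion's actual config loader:

```python
# Illustrative sketch: merge key=value CLI overrides into a base config dict.
# Not maxdiffusion's real pyconfig implementation — values stay as strings here.
def apply_overrides(base: dict, argv: list) -> dict:
    config = dict(base)
    for arg in argv:
        # partition splits on the first '=', so values may themselves contain '='
        key, _, value = arg.partition("=")
        config[key] = value
    return config

base = {"num_inference_steps": "50", "output_dir": "."}
overrides = ["num_inference_steps=28", "output_dir=/tmp/", "run_name=flux_test"]
config = apply_overrides(base, overrides)
print(config["num_inference_steps"])  # -> 28
```

Anything not overridden keeps its value from the YAML file.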
## Part 4: Accessing Your Generated Image

The generation script saves the final image to its working directory (`/app`) inside the container. Here is the complete workflow to get that image onto your laptop.

### Step 1: Copy the Image from the Container to the DGX Spark

Open a new terminal window. Do not close the terminal where the container is running.
First, find your container's ID:

```bash
docker ps
```

Look for the container running the `maxdiffusion-arm-gpu` image and note its ID (e.g., `9049895399fc`).
Now, copy the image from the container to a temporary location on the DGX Spark and fix its permissions.

```bash
# Copy the file to the /tmp/ directory on the DGX Spark
docker cp 9049895399fc:/app/flux_0.png /tmp/flux_0.png

# Change the file's owner to your user to avoid permission errors
sudo chown username:username /tmp/flux_0.png
```

### Step 2: Copy the Image from the DGX Spark to Your Laptop

Now, open the Terminal app on your laptop and use the `scp` (secure copy) command to download the file from the DGX Spark.

```bash
scp username@spark:/tmp/flux_0.png .
```

This command downloads `flux_0.png` to the current directory on your laptop. You can now view your generated image!
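The `spark` hostname above assumes your laptop can already resolve the DGX Spark by that name. If it can't, an SSH config entry like the following makes the `scp` command work as written; the hostname, IP, and user are placeholders to replace with your own values:

```
# ~/.ssh/config on your laptop (illustrative values — substitute your own)
Host spark
    HostName 192.168.1.50
    User username
```

After adding this, `ssh spark` and `scp username@spark:...` (or simply `scp spark:...`) both resolve to the machine.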
## Troubleshooting and Common Pitfalls

Here are solutions to common issues you might encounter:

- Error: `pip: command not found` during the Docker build.
  - **Cause**: The base Docker image doesn't have pip in the system's default PATH.
  - **Solution**: The provided Dockerfile fixes this by explicitly installing python3-pip and using update-alternatives to create the necessary symbolic links.
- Error: `externally-managed-environment` during `pip install`.
  - **Cause**: Newer versions of Debian/Ubuntu protect system Python packages from being modified by pip.
  - **Solution**: The `ENV PIP_BREAK_SYSTEM_PACKAGES=1` line in the `Dockerfile` safely bypasses this protection within the container's isolated environment.
- Error: `OSError: ...is not a local folder and is not a valid model identifier`.
  - **Cause**: The script is trying to download models from the Hugging Face Hub because it cannot find them locally.
  - **Solution**: Launch the container with the `-v ~/.cache/huggingface:/root/.cache/huggingface` flag, which gives the container access to your local model cache.
- Error: `open ... permission denied` when trying to access a copied file.
  - **Cause**: Files copied out of a Docker container with `docker cp` are owned by the root user by default.
  - **Solution**: After copying the file to the DGX Spark, immediately run `sudo chown your_user:your_user /path/to/file` to take ownership before trying to access or transfer it.
- Can't find the generated image.
  - **Cause**: The script may not save the image to the directory specified by the `output_dir` argument.
  - **Solution**: Always check the script's source code to confirm the final save location. As we discovered, generate_flux.py saves to the current working directory (`/app`), not `/tmp`. Knowing this allows you to copy the file from the correct location.
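When you are unsure where a script wrote its output, a short stdlib snippet can find the most recently modified PNG under a directory tree; point the search root at wherever you suspect the file landed (e.g. `/app` inside the container):

```python
# Find the most recently modified .png file under a directory tree.
from pathlib import Path
from typing import Optional

def newest_png(root: str) -> Optional[Path]:
    pngs = list(Path(root).rglob("*.png"))
    # max() by modification time; None when the tree has no PNGs (or doesn't exist)
    return max(pngs, key=lambda p: p.stat().st_mtime) if pngs else None

print(newest_png("/app"))  # e.g. /app/flux_0.png inside the container; None if nothing found
```

This works equally well on the host against the mounted output directory.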
- If a process requires more memory than the available RAM, it can be killed partway through with an "Out-of-Memory" (OOM) error.
  - **Swap memory is your safety net.** It's a designated space on your hard drive that the operating system uses as a "virtual" extension of your RAM. When RAM is full, the system moves less active data to the slower swap space, freeing up RAM for the immediate task. While swap is slower than RAM, it's far better than a killed job, helping your long-running training or generation jobs complete successfully. For a machine with 119GB of RAM, adding 64GB of swap provides a robust buffer for memory-intensive operations.
  - Step 1: Create a 64GB swap file.
    - Run these commands on your DGX Spark to create, format, and enable a permanent 64GB swap file.
```bash
# Instantly allocate a 64GB file
sudo fallocate -l 64G /swapfile
# Set secure permissions (only root can access)
sudo chmod 600 /swapfile
# Format the file as swap space
sudo mkswap /swapfile
# Enable the swap file for the current session
sudo swapon /swapfile
# Add the swap file to the system's startup configuration to make it permanent
echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab
```

- Step 2: Verify swap is active.
  - Check that the swap space is correctly configured.

```bash
free -h
# The output should now show a 64GB total for Swap.
```

requirements.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,11 +1,11 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-jax>=0.6.2
+jax>=0.7.2
 jaxlib>=0.4.30
 grain
 google-cloud-storage>=2.17.0
 absl-py
 datasets
-flax>=0.11.0
+flax>=0.12.0
 optax>=0.2.3
 torch>=2.6.0
 torchvision>=0.20.1
```

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,13 +1,13 @@
 # Requirements for Building the MaxDifussion Docker Image
 # These requirements are additional to the dependencies present in the JAX AI base image.
 --extra-index-url https://download.pytorch.org/whl/cpu
-jax>=0.6.2
+jax>=0.7.2
 jaxlib>=0.4.30
 grain
 google-cloud-storage>=2.17.0
 absl-py
 datasets
-flax>=0.10.2
+flax>=0.12.0
 optax>=0.2.3
 torch>=2.6.0
 torchvision>=0.20.1
```

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager(
   if checkpoint_type == FLUX_CHECKPOINT:
     item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
   elif checkpoint_type == WAN_CHECKPOINT:
-    item_names = ("wan_state", "wan_config")
+    item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
   else:
     item_names = (
         "unet_config",
```

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 16 additions & 9 deletions
```diff
@@ -19,6 +19,7 @@
 
 import jax
 import numpy as np
+from typing import Optional, Tuple
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
@@ -33,6 +34,7 @@ class WanCheckpointer(ABC):
   def __init__(self, config, checkpoint_type):
     self.config = config
     self.checkpoint_type = checkpoint_type
+    self.opt_state = None
 
     self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
         self.config.checkpoint_dir,
@@ -49,15 +51,15 @@ def _create_optimizer(self, model, config, learning_rate):
     tx = max_utils.create_optimizer(config, learning_rate_scheduler)
     return tx, learning_rate_scheduler
 
-  def load_wan_configs_from_orbax(self, step):
+  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
     if step is None:
       step = self.checkpoint_manager.latest_step()
       max_logging.log(f"Latest WAN checkpoint step: {step}")
     if step is None:
-      return None
+      max_logging.log("No WAN checkpoint found.")
+      return None, None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
     metadatas = self.checkpoint_manager.item_metadata(step)
-
     transformer_metadata = metadatas.wan_state
     abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
     params_restore = ocp.args.PyTreeRestore(
@@ -73,27 +75,32 @@
         step=step,
         args=ocp.args.Composite(
             wan_state=params_restore,
-            # wan_state=params_restore_util_way,
             wan_config=ocp.args.JsonRestore(),
         ),
     )
-    return restored_checkpoint
+    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
+    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
+    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
+    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
+    return restored_checkpoint, step
 
   def load_diffusers_checkpoint(self):
     pipeline = WanPipeline.from_pretrained(self.config)
     return pipeline
 
-  def load_checkpoint(self, step=None):
-    restored_checkpoint = self.load_wan_configs_from_orbax(step)
-
+  def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
+    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
+    opt_state = None
     if restored_checkpoint:
       max_logging.log("Loading WAN pipeline from checkpoint")
       pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
+      if "opt_state" in restored_checkpoint["wan_state"].keys():
+        opt_state = restored_checkpoint["wan_state"]["opt_state"]
     else:
       max_logging.log("No checkpoint found, loading default pipeline.")
       pipeline = self.load_diffusers_checkpoint()
 
-    return pipeline
+    return pipeline, opt_state, step
 
   def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
     """Saves the training state and model configurations."""
```
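After this change, `load_checkpoint` returns a `(pipeline, opt_state, step)` tuple instead of a bare pipeline, so call sites must unpack three values and handle the `None` fallbacks. A self-contained sketch of that calling convention, using stub types rather than the real maxdiffusion classes:

```python
# Illustrative stub of the new (pipeline, opt_state, step) return convention.
from typing import Optional, Tuple

class StubPipeline:
    """Stands in for WanPipeline in this sketch."""

def load_checkpoint(restored: Optional[dict], step: Optional[int]) -> Tuple[StubPipeline, Optional[dict], Optional[int]]:
    opt_state = None
    if restored:
        # Resuming: pull the optimizer state out of the restored train state if present.
        if "opt_state" in restored["wan_state"]:
            opt_state = restored["wan_state"]["opt_state"]
    # Fresh start: opt_state stays None and the caller builds a new optimizer.
    return StubPipeline(), opt_state, step

# No checkpoint found: everything except the pipeline is None.
pipeline, opt_state, step = load_checkpoint(None, None)
assert opt_state is None and step is None

# Checkpoint with optimizer state restored at step 100.
restored = {"wan_state": {"params": {}, "opt_state": {"count": 3}}}
pipeline, opt_state, step = load_checkpoint(restored, 100)
print(opt_state, step)  # -> {'count': 3} 100
```

Training code can then decide whether to resume the optimizer from `opt_state` or create a fresh one when it is `None`.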
