Skip to content

Commit 6bcb87f

Browse files
committed
Merge branch 'main' into ltx2-dev
2 parents c9dafef + 85ba65e commit 6bcb87f

7 files changed

Lines changed: 8 additions & 18 deletions

File tree

.github/workflows/UploadDockerImages.yml

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,3 @@ jobs:
4040
- name: build maxdiffusion jax nightly image
4141
run: |
4242
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
43-
44-
build-gpu-image:
45-
runs-on: ["self-hosted", "e2", "cpu"]
46-
steps:
47-
- uses: actions/checkout@v3
48-
- name: Cleanup old docker images
49-
run: docker system prune --all --force
50-
- name: build maxdiffusion jax stable stack gpu image
51-
run: |
52-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
53-
- name: build maxdiffusion jax nightly image
54-
run: |
55-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ MaxDiffusion supports
5252
* Wan2.1 text2vid (training and inference).
5353
* Wan2.2 text2vid (inference).
5454

55+
**Note on GPU Support:** GPU support is not actively maintained, but contributions are welcome
56+
5557

5658
# Table of Contents
5759

@@ -176,7 +178,7 @@ After installation completes, run the training script.
176178

177179
```bash
178180
BUCKET_NAME=my-bucket
179-
gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
181+
gcloud storage cp --recursive $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
180182
```
181183

182184
Now run the training command:
@@ -703,4 +705,3 @@ This script will automatically format your code with `pyink` and help you identi
703705
704706
The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.
705707
706-

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ jaxlib>=0.4.30
44
grain
55
google-cloud-storage>=2.17.0
66
absl-py
7+
chex
78
datasets
89
flax>=0.12.0
910
optax>=0.2.3

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ jaxlib>=0.4.30
66
grain
77
google-cloud-storage>=2.17.0
88
absl-py
9+
chex
910
datasets
1011
flax>=0.12.0
1112
optax>=0.2.3

src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Load and save a checkpoint. This is useful for uploading checkpoints to gcs
1818
and later loading them from gcs directly.
1919
After calling this script, use gsutil to upload the weights to a bucket:
20-
gsutil -m cp -r sd-model-finetuned gs://<your-bucket>/sd_checkpoint/
20+
gcloud storage cp --recursive sd-model-finetuned gs://<your-bucket>/sd_checkpoint/
2121
"""
2222

2323
from typing import Sequence

src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Load and save a checkpoint. This is useful for uploading checkpoints to gcs
1818
and later loading them from gcs directly.
1919
After calling this script, use gsutil to upload the weights to a bucket:
20-
gsutil -m cp -r sdxl-model-finetuned gs://<your-bucket>/sdxl_1.0_base/
20+
gcloud storage cp --recursive sdxl-model-finetuned gs://<your-bucket>/sdxl_1.0_base/
2121
"""
2222

2323
from typing import Sequence

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
415415
max_logging.log(f"Saving final checkpoint for step {step}")
416416
self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params)
417417
self.checkpointer.checkpoint_manager.wait_until_finished()
418-
# load new state for trained tranformer
418+
# load new state for trained transformer
419419
pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
420420
return pipeline
421421

0 commit comments

Comments
 (0)