Skip to content

Commit 10fa395

Browse files
committed
resolved README conflict
2 parents 738afda + f279995 commit 10fa395

35 files changed

Lines changed: 527 additions & 93 deletions

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ jobs:
3131
strategy:
3232
fail-fast: false
3333
matrix:
34-
tpu-type: ["v4-8"]
34+
tpu-type: ["v5p-8"]
3535
name: "TPU test (${{ matrix.tpu-type }})"
36-
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
36+
runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
3737
steps:
3838
- uses: actions/checkout@v4
3939
- name: Set up Python 3.12
@@ -54,7 +54,7 @@ jobs:
5454
ruff check .
5555
- name: PyTest
5656
run: |
57-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
57+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
5858
# add_pull_ready:
5959
# if: github.ref != 'refs/heads/main'
6060
# permissions:

.github/workflows/XLML.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Add Testgrid Link to PR
22

33
on:
44
pull_request:
5-
types: [opened, synchronize]
5+
types: [opened]
66

77
jobs:
88
add_testgrid_link:

end_to_end/tpu/test_sdxl_training_loss.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ done
1212
TRAIN_CMD="python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
1313
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
1414
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True \
15-
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 \
15+
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl resolution=1024 per_device_batch_size=1 \
1616
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=$STEPS attention=flash run_name=sdxl-fsdp-v5p-64-ddp enable_profiler=True \
1717
run_name=$RUN_NAME \
1818
output_dir=$OUTPUT_DIR "

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ sentencepiece
3535
aqtp
3636
imageio==2.37.0
3737
imageio-ffmpeg==0.6.0
38-
hf_transfer>=0.1.9
38+
hf_transfer>=0.1.9
39+
qwix@git+https://github.com/google/qwix.git

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ sentencepiece
3535
aqtp
3636
imageio==2.37.0
3737
imageio-ffmpeg==0.6.0
38-
hf_transfer>=0.1.9
38+
hf_transfer>=0.1.9
39+
qwix@git+https://github.com/google/qwix.git

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616

1717
from abc import ABC
18-
from flax import nnx
1918
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2019
from ..pipelines.wan.wan_pipeline import WanPipeline
2120
from .. import max_logging, max_utils
@@ -42,7 +41,7 @@ def _create_optimizer(self, model, config, learning_rate):
4241
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
4342
)
4443
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
45-
return nnx.Optimizer(model, tx), learning_rate_scheduler
44+
return tx, learning_rate_scheduler
4645

4746
def load_wan_configs_from_orbax(self, step):
4847
max_logging.log("Restoring stable diffusion configs")

src/maxdiffusion/configs/base14.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,4 +231,5 @@ cache_dreambooth_dataset: False
231231
quantization: ''
232232
# Shard the range finding operation for quantization. By default this is set to number of slices.
233233
quantization_local_shard_count: -1
234+
use_qwix_quantization: False
234235
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base21.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,4 @@ quantization: ''
232232
# Shard the range finding operation for quantization. By default this is set to number of slices.
233233
quantization_local_shard_count: -1
234234
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
235+
use_qwix_quantization: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,5 @@ cache_dreambooth_dataset: False
246246
quantization: ''
247247
# Shard the range finding operation for quantization. By default this is set to number of slices.
248248
quantization_local_shard_count: -1
249+
use_qwix_quantization: False
249250
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,5 +276,6 @@ controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Goo
276276
quantization: ''
277277
# Shard the range finding operation for quantization. By default this is set to number of slices.
278278
quantization_local_shard_count: -1
279+
use_qwix_quantization: False
279280
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
280281

0 commit comments

Comments
 (0)