Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ jobs:
- name: Analysing the code with ruff
run: |
ruff check .
+- name: version check
+run: |
+python --version
+pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
-run: |
-HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -x
+run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
+HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
6 changes: 3 additions & 3 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from maxdiffusion import multihost_dataloading, max_logging

AUTOTUNE = tf.data.AUTOTUNE

+os.environ["TOKENIZERS_PARALLELISM"] = "false"

def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
dataset = dataset.with_format("tensorflow")[:]
Expand Down Expand Up @@ -50,7 +50,7 @@ def make_tf_iterator(
function=tokenize_fn,
batched=True,
remove_columns=[config.caption_column],
-num_proc=1 if config.cache_latents_text_encoder_outputs else config.tokenize_captions_num_proc,
+num_proc=None,
desc="Running tokenizer on train dataset",
)
# need to do it before load_as_tf_dataset
Expand All @@ -60,7 +60,7 @@ def make_tf_iterator(
function=image_transforms_fn,
batched=True,
remove_columns=[config.image_column],
-num_proc=1 if config.cache_latents_text_encoder_outputs else config.transform_images_num_proc,
+num_proc=None,
desc="Transforming images",
)
if config.cache_latents_text_encoder_outputs:
Expand Down
10 changes: 5 additions & 5 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from PIL import Image

AUTOTUNE = tf.data.experimental.AUTOTUNE

+os.environ["TOKENIZERS_PARALLELISM"] = "false"

def make_data_iterator(
config,
Expand Down Expand Up @@ -159,7 +159,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
function=tokenize_fn,
batched=True,
remove_columns=[INSTANCE_PROMPT_IDS],
-num_proc=1,
+num_proc=None,
desc="Running tokenizer on instance dataset",
)
rng = jax.random.key(config.seed)
Expand All @@ -177,7 +177,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
function=transform_images_fn,
batched=True,
remove_columns=[INSTANCE_IMAGES],
-num_proc=1,
+num_proc=None,
desc="Running vae on instance dataset",
)

Expand All @@ -188,7 +188,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
function=tokenize_fn,
batched=True,
remove_columns=[CLASS_PROMPT_IDS],
-num_proc=1,
+num_proc=None,
desc="Running tokenizer on class dataset",
)
transform_images_fn = partial(
Expand All @@ -204,7 +204,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
function=transform_images_fn,
batched=True,
remove_columns=[CLASS_IMAGES],
-num_proc=1,
+num_proc=None,
desc="Running vae on instance dataset",
)

Expand Down
4 changes: 1 addition & 3 deletions src/maxdiffusion/tests/input_pipeline_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import subprocess
import unittest
from absl.testing import absltest

import numpy as np
-import pytest
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import jax
Expand Down Expand Up @@ -70,7 +68,6 @@ class InputPipelineInterface(unittest.TestCase):
def setUp(self):
InputPipelineInterface.dummy_data = {}

-@pytest.mark.skip(reason="Debug segfault")
def test_make_dreambooth_train_iterator(self):

instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class"
Expand All @@ -85,6 +82,7 @@ def test_make_dreambooth_train_iterator(self):
os.path.join(THIS_DIR, "..", "configs", "base14.yml"),
"cache_latents_text_encoder_outputs=True",
"dataset_name=my_dreambooth_dataset",
+"transform_images_num_proc=1",
f"instance_data_dir={instance_class_local_dir}",
f"class_data_dir={class_class_local_dir}",
"instance_prompt=photo of ohwx dog",
Expand Down
Loading