Skip to content

Commit 98de3ed

Browse files
committed
Fix multiprocessing segfault
1 parent f41036d commit 98de3ed

4 files changed

Lines changed: 4 additions & 6 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ jobs:
5353
run: |
5454
ruff check .
5555
- name: PyTest
56-
run: |
57-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -x
56+
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
57+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
5858
# add_pull_ready:
5959
# if: github.ref != 'refs/heads/main'
6060
# permissions:

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from maxdiffusion import multihost_dataloading, max_logging
2323

2424
AUTOTUNE = tf.data.AUTOTUNE
25-
25+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2626

2727
def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
2828
dataset = dataset.with_format("tensorflow")[:]

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from PIL import Image
4141

4242
AUTOTUNE = tf.data.experimental.AUTOTUNE
43-
43+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
4444

4545
def make_data_iterator(
4646
config,

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from absl.testing import absltest
2424

2525
import numpy as np
26-
import pytest
2726
import tensorflow as tf
2827
import tensorflow.experimental.numpy as tnp
2928
import jax
@@ -70,7 +69,6 @@ class InputPipelineInterface(unittest.TestCase):
7069
def setUp(self):
7170
InputPipelineInterface.dummy_data = {}
7271

73-
@pytest.mark.skip(reason="Debug segfault")
7472
def test_make_dreambooth_train_iterator(self):
7573

7674
instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class"

0 commit comments

Comments
 (0)