Skip to content

Commit 5f5749a

Browse files
check if torch is causing the issue
1 parent 333e4f5 commit 5f5749a

2 files changed

Lines changed: 41 additions & 18 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,27 +98,28 @@ jobs:
9898
# pip install tensorflow-cpu
9999
pip freeze
100100
101-
- name: Check per_device_batch_size
101+
- name: Check devices
102102
run: |
103103
python -c "import jax; print(jax.devices())"
104+
python verify_conflict.py
104105
105-
- name: Run MaxDiffusion Training
106-
run: |
107-
# This command is adapted from your DAG for a single-slice configuration.
108-
NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
109-
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110-
hardware=gpu \
111-
train_new_unet=true \
112-
train_text_encoder=false \
113-
cache_latents_text_encoder_outputs=true \
114-
per_device_batch_size=1 \
115-
attention=dot_product \
116-
activations_dtype=bfloat16 \
117-
weights_dtype=bfloat16 \
118-
max_train_steps=200 \
119-
enable_profiler=True \
120-
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121-
output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
106+
# - name: Run MaxDiffusion Training
107+
# run: |
108+
# # This command is adapted from your DAG for a single-slice configuration.
109+
# NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
110+
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
111+
# hardware=gpu \
112+
# train_new_unet=true \
113+
# train_text_encoder=false \
114+
# cache_latents_text_encoder_outputs=true \
115+
# per_device_batch_size=1 \
116+
# attention=dot_product \
117+
# activations_dtype=bfloat16 \
118+
# weights_dtype=bfloat16 \
119+
# max_train_steps=200 \
120+
# enable_profiler=True \
121+
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
122+
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
122123

123124
# jobs:
124125
# build:

verify_conflict.sh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
print("--- PyTorch vs. JAX Conflict Test ---")
2+
3+
print("\nStep 1: Attempting to import torch...")
4+
try:
5+
import torch
6+
print(f"Successfully imported torch version: {torch.__version__}")
7+
# This check will confirm you have the CPU-only version
8+
print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
9+
except Exception as e:
10+
print(f"Failed to import torch: {e}")
11+
12+
13+
print("\nStep 2: Now, attempting to initialize JAX...")
14+
try:
15+
import jax
16+
devices = jax.devices()
17+
print("\n--- RESULT: SUCCESS ---")
18+
print(f"JAX initialized correctly and found devices: {devices}")
19+
except Exception as e:
20+
print("\n--- RESULT: FAILURE ---")
21+
print("JAX failed to initialize after PyTorch was imported.")
22+
print(f"JAX Error: {e}")

0 commit comments

Comments
 (0)