Skip to content

Commit 6d9239a

Browse files
committed
resolving comments
1 parent 4281c24 commit 6d9239a

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
5959
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
6060
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
61-
# add_pull_ready:q
61+
# add_pull_ready
6262
# if: github.ref != 'refs/heads/main'
6363
# permissions:
6464
# checks: read

src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def activate_profiler(config):
8181
if jax.process_index() == 0 and config.enable_profiler:
8282
trace_dir = config.tensorboard_dir
8383
if trace_dir.startswith("gs://"):
84-
trace_dir = "/tmp/profiler_traces"
84+
trace_dir = os.path.join("/tmp/profiler_traces", config.run_name)
8585
os.makedirs(trace_dir, exist_ok=True)
8686
max_logging.log(f"Starting profiler trace in: {trace_dir}")
8787
jax.profiler.start_trace(trace_dir)
@@ -93,7 +93,7 @@ def deactivate_profiler(config):
9393

9494
trace_dir = config.tensorboard_dir
9595
if trace_dir.startswith("gs://"):
96-
local_dir = "/tmp/profiler_traces"
96+
local_dir = os.path.join("/tmp/profiler_traces", config.run_name)
9797
if os.path.exists(local_dir):
9898
max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...")
9999
client = storage.Client()

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(self, dim: int, mode: str) -> None:
108108
torch.nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest"), torch.nn.Conv2d(dim, dim // 2, 3, padding=1)
109109
)
110110
elif mode == "upsample3d":
111-
raise Exception("downsample3d not supported")
111+
raise Exception("upsample3d not supported")
112112

113113
elif mode == "downsample2d":
114114
self.resample = torch.nn.Sequential(torch.nn.ZeroPad2d((0, 1, 0, 1)), torch.nn.Conv2d(dim, dim, 3, stride=(2, 2)))

0 commit comments

Comments
 (0)