Skip to content

Commit 3516f7b

Browse files
committed
fix multiple issues on GPU
1 parent 727fdcb commit 3516f7b

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262

6363
from google.cloud import storage
6464

65+
libcudart = cdll.LoadLibrary("libcudart.so")
66+
6567
FrozenDict = core.frozen_dict.FrozenDict
6668

6769

@@ -78,12 +80,18 @@ def l2norm_pytree(x):
7880

7981
def activate_profiler(config):
8082
if jax.process_index() == 0 and config.enable_profiler:
81-
jax.profiler.start_trace(config.tensorboard_dir)
83+
if config.profiler == 'nsys':
84+
libcudart.cudaProfilerStart()
85+
else:
86+
jax.profiler.start_trace(config.tensorboard_dir)
8287

8388

8489
def deactivate_profiler(config):
8590
if jax.process_index() == 0 and config.enable_profiler:
86-
jax.profiler.stop_trace()
91+
if config.profiler == 'nsys':
92+
libcudart.cudaProfilerStop()
93+
else:
94+
jax.profiler.stop_trace()
8795

8896

8997
def initialize_summary_writer(config):

0 commit comments

Comments
 (0)