We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 727fdcb, commit 3516f7b (Copy full SHA for 3516f7b)
1 file changed
src/maxdiffusion/max_utils.py
@@ -62,6 +62,8 @@
62
63
from google.cloud import storage
64
65
+libcudart = cdll.LoadLibrary("libcudart.so")
66
+
67
FrozenDict = core.frozen_dict.FrozenDict
68
69
@@ -78,12 +80,18 @@ def l2norm_pytree(x):
78
80
79
81
def activate_profiler(config):
82
if jax.process_index() == 0 and config.enable_profiler:
- jax.profiler.start_trace(config.tensorboard_dir)
83
+ if config.profiler == 'nsys':
84
+ libcudart.cudaProfilerStart()
85
+ else:
86
+ jax.profiler.start_trace(config.tensorboard_dir)
87
88
89
def deactivate_profiler(config):
90
if jax.process_index() == 0 and config.enable_profiler:
- jax.profiler.stop_trace()
91
+ if config.profiler == 'nsys':
92
+ libcudart.cudaProfilerStop()
93
+ else:
94
+ jax.profiler.stop_trace()
95
96
97
def initialize_summary_writer(config):
0 commit comments