1+ #! /bin/bash
2+ set -e
3+ set -u
4+ set -o pipefail
5+
6+ : " ${NNODES:? Must set NNODES} "
7+ : " ${NODE_RANK:? Must set NODE_RANK} "
8+ : " ${JAX_COORDINATOR_PORT:? Must set JAX_COORDINATOR_PORT} "
9+ : " ${JAX_COORDINATOR_ADDRESS:? Must set JAX_COORDINATOR_ADDRESS} "
10+ : " ${GPUS_PER_NODE:? Must set GPUS_PER_NODE} "
11+ : " ${COMMAND:? Must set COMMAND} "
12+
13+
14+ export GPUS_PER_NODE=$GPUS_PER_NODE
15+ export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT
16+ export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS
17+
18+ set_nccl_gpudirect_tcpx_specific_configuration () {
19+ if [[ " $USE_GPUDIRECT " == " tcpx" ]] || [[ " $USE_GPUDIRECT " == " fastrak" ]]; then
20+ export CUDA_DEVICE_MAX_CONNECTIONS=1
21+ export NCCL_CROSS_NIC=0
22+ export NCCL_DEBUG=INFO
23+ export NCCL_DYNAMIC_CHUNK_SIZE=524288
24+ export NCCL_NET_GDR_LEVEL=PIX
25+ export NCCL_NVLS_ENABLE=0
26+ export NCCL_P2P_NET_CHUNKSIZE=524288
27+ export NCCL_P2P_NVL_CHUNKSIZE=1048576
28+ export NCCL_P2P_PCI_CHUNKSIZE=524288
29+ export NCCL_PROTO=Simple
30+ export NCCL_SOCKET_IFNAME=eth0
31+ export NVTE_FUSED_ATTN=1
32+ export TF_CPP_MAX_LOG_LEVEL=100
33+ export TF_CPP_VMODULE=profile_guided_latency_estimator=10
34+ export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
35+ shopt -s globstar nullglob
36+ IFS=:$IFS
37+ set -- /usr/local/cuda-* /compat
38+ export LD_LIBRARY_PATH=" ${1+: " $* " } :${LD_LIBRARY_PATH} :/usr/local/tcpx/lib64"
39+ IFS=${IFS# ?}
40+ shopt -u globstar nullglob
41+
42+ if [[ " $USE_GPUDIRECT " == " tcpx" ]]; then
43+ echo " Using GPUDirect-TCPX"
44+ export NCCL_ALGO=Ring
45+ export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
46+ export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
47+ export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
48+ export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
49+ export NCCL_GPUDIRECTTCPX_RX_BINDINGS=" eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
50+ export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
51+ export NCCL_GPUDIRECTTCPX_TX_BINDINGS=" eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
52+ export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
53+ export NCCL_MAX_NCHANNELS=12
54+ export NCCL_MIN_NCHANNELS=12
55+ export NCCL_NSOCKS_PERTHREAD=4
56+ export NCCL_P2P_PXN_LEVEL=0
57+ export NCCL_SOCKET_NTHREADS=1
58+ elif [[ " $USE_GPUDIRECT " == " fastrak" ]]; then
59+ echo " Using GPUDirect-TCPFasTrak"
60+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
61+ export NCCL_ALGO=Ring,Tree
62+ export NCCL_BUFFSIZE=8388608
63+ export NCCL_FASTRAK_CTRL_DEV=eth0
64+ export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
65+ export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
66+ export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
67+ export NCCL_FASTRAK_NUM_FLOWS=2
68+ export NCCL_FASTRAK_USE_LLCM=1
69+ export NCCL_FASTRAK_USE_SNAP=1
70+ export NCCL_MIN_NCHANNELS=4
71+ export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
72+ export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
73+ export NCCL_TUNER_PLUGIN=libnccl-tuner.so
74+ fi
75+ else
76+ echo " NOT using GPUDirect"
77+ fi
78+ }
79+
80+ echo " LD_LIBRARY_PATH ${LD_LIBRARY_PATH} "
81+
82+ set_nccl_gpudirect_tcpx_specific_configuration
83+
84+ wait_all_success_or_exit () {
85+ # https://www.baeldung.com/linux/background-process-get-exit-code
86+ local pids=(" $@ " )
87+ while [[ ${# pids[@]} -ne 0 ]]; do
88+ all_success=" true"
89+ for pid in " ${pids[@]} " ; do
90+ code=$( non_blocking_wait " $pid " )
91+ if [[ $code -ne 127 ]]; then
92+ if [[ $code -ne 0 ]]; then
93+ echo " PID $pid failed with exit code $code "
94+ exit " $code "
95+ fi
96+ else
97+ all_success=" false"
98+ fi
99+ done
100+ if [[ $all_success == " true" ]]; then
101+ echo " All pids succeeded"
102+ break
103+ fi
104+ sleep 5
105+ done
106+ }
107+ non_blocking_wait () {
108+ # https://www.baeldung.com/linux/background-process-get-exit-code
109+ local pid=$1
110+ local code=127 # special code to indicate not-finished
111+ if [[ ! -d " /proc/$pid " ]]; then
112+ wait " $pid "
113+ code=$?
114+ fi
115+ echo $code
116+ }
117+
118+ resolve_coordinator_ip () {
119+ local lookup_attempt=1
120+ local max_coordinator_lookups=500
121+ local coordinator_found=false
122+ local coordinator_ip_address=" "
123+
124+ echo " Coordinator Address $JAX_COORDINATOR_ADDRESS "
125+
126+ while [[ " $coordinator_found " = false && $lookup_attempt -le $max_coordinator_lookups ]]; do
127+ coordinator_ip_address=$( nslookup " $JAX_COORDINATOR_ADDRESS " 2> /dev/null | awk ' /^Address: / { print $2 }' | head -n 1)
128+ if [[ -n " $coordinator_ip_address " ]]; then
129+ coordinator_found=true
130+ echo " Coordinator IP address: $coordinator_ip_address "
131+ export JAX_COORDINATOR_IP=$coordinator_ip_address
132+ return 0
133+ else
134+ echo " Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt , retrying..."
135+ (( lookup_attempt++ ))
136+ sleep 1
137+ fi
138+ done
139+
140+ if [[ " $coordinator_found " = false ]]; then
141+ echo " Failed to resolve coordinator address after $max_coordinator_lookups attempts."
142+ return 1
143+ fi
144+ }
145+
146+ # Resolving coordinator IP
147+ set +e
148+ resolve_coordinator_ip
149+ set -e
150+
151+ PIDS=()
152+ eval ${COMMAND} &
153+ PID=$!
154+ PIDS+=($PID )
155+
156+ wait_all_success_or_exit " ${PIDS[@]} "
0 commit comments