Skip to content

Commit 51e1db1

Browse files
authored
Support building GPU docker image for MaxDiffusion Model (#121)
* adding gpu docker file * add more gpu dependency files, unify working directory to match xpk setup, add hardware gpu option in yml, add jax multi-host support for gpu * fix identation * reformatting * add gpu_multi_process_run.sh, unify working directory, update requirement to fix import error, add jax[cuda] install instruction more non-pinned mode when device is GPU * reformatting * resolve comments * delete gpu pinned mode
1 parent 7ca0857 commit 51e1db1

14 files changed

Lines changed: 367 additions & 69 deletions

docker_build_dependency_image.sh

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ if [[ -z ${MODE} ]]; then
4646
echo "Default MODE=${MODE}"
4747
fi
4848

49+
if [[ -z ${DEVICE} ]]; then
50+
export DEVICE=tpu
51+
echo "Default DEVICE=${DEVICE}"
52+
fi
53+
echo "DEVICE=${DEVICE}"
54+
4955
if [[ -z ${JAX_VERSION+x} ]] ; then
5056
export JAX_VERSION=NONE
5157
echo "Default JAX_VERSION=${JAX_VERSION}"
@@ -55,22 +61,31 @@ COMMIT_HASH=$(git rev-parse --short HEAD)
5561

5662
echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ."
5763

58-
if [[ "${MODE}" == "stable_stack" ]]; then
59-
if [[ ! -v BASEIMAGE ]]; then
60-
echo "Erroring out because BASEIMAGE is unset, please set it!"
61-
exit 1
64+
if [[ ${DEVICE} == "gpu" ]]; then
65+
if [[ ${MODE} == "pinned" ]]; then
66+
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17
67+
else
68+
export BASEIMAGE=ghcr.io/nvidia/jax:base
69+
fi
70+
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
71+
else
72+
if [[ "${MODE}" == "stable_stack" ]]; then
73+
if [[ ! -v BASEIMAGE ]]; then
74+
echo "Erroring out because BASEIMAGE is unset, please set it!"
75+
exit 1
76+
fi
77+
docker build --no-cache \
78+
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
79+
--build-arg COMMIT_HASH=${COMMIT_HASH} \
80+
--network=host \
81+
-t ${LOCAL_IMAGE_NAME} \
82+
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
83+
else
84+
docker build --no-cache \
85+
--network=host \
86+
--build-arg MODE=${MODE} \
87+
--build-arg JAX_VERSION=${JAX_VERSION} \
88+
-t ${LOCAL_IMAGE_NAME} \
89+
-f maxdiffusion_dependencies.Dockerfile .
6290
fi
63-
docker build --no-cache \
64-
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
65-
--build-arg COMMIT_HASH=${COMMIT_HASH} \
66-
--network=host \
67-
-t ${LOCAL_IMAGE_NAME} \
68-
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
69-
else
70-
docker build --no-cache \
71-
--network=host \
72-
--build-arg MODE=${MODE} \
73-
--build-arg JAX_VERSION=${JAX_VERSION} \
74-
-t ${LOCAL_IMAGE_NAME} \
75-
-f maxdiffusion_dependencies.Dockerfile .
7691
fi

gpu_multi_process_run.sh

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#! /bin/bash
2+
set -e
3+
set -u
4+
set -o pipefail
5+
6+
: "${NNODES:?Must set NNODES}"
7+
: "${NODE_RANK:?Must set NODE_RANK}"
8+
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
9+
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
10+
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
11+
: "${COMMAND:?Must set COMMAND}"
12+
13+
14+
export GPUS_PER_NODE=$GPUS_PER_NODE
15+
export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT
16+
export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS
17+
18+
set_nccl_gpudirect_tcpx_specific_configuration() {
19+
if [[ "$USE_GPUDIRECT" == "tcpx" ]] || [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
20+
export CUDA_DEVICE_MAX_CONNECTIONS=1
21+
export NCCL_CROSS_NIC=0
22+
export NCCL_DEBUG=INFO
23+
export NCCL_DYNAMIC_CHUNK_SIZE=524288
24+
export NCCL_NET_GDR_LEVEL=PIX
25+
export NCCL_NVLS_ENABLE=0
26+
export NCCL_P2P_NET_CHUNKSIZE=524288
27+
export NCCL_P2P_NVL_CHUNKSIZE=1048576
28+
export NCCL_P2P_PCI_CHUNKSIZE=524288
29+
export NCCL_PROTO=Simple
30+
export NCCL_SOCKET_IFNAME=eth0
31+
export NVTE_FUSED_ATTN=1
32+
export TF_CPP_MAX_LOG_LEVEL=100
33+
export TF_CPP_VMODULE=profile_guided_latency_estimator=10
34+
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
35+
shopt -s globstar nullglob
36+
IFS=:$IFS
37+
set -- /usr/local/cuda-*/compat
38+
export LD_LIBRARY_PATH="${1+:"$*"}:${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64"
39+
IFS=${IFS#?}
40+
shopt -u globstar nullglob
41+
42+
if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then
43+
echo "Using GPUDirect-TCPX"
44+
export NCCL_ALGO=Ring
45+
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
46+
export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
47+
export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
48+
export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
49+
export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
50+
export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
51+
export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
52+
export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
53+
export NCCL_MAX_NCHANNELS=12
54+
export NCCL_MIN_NCHANNELS=12
55+
export NCCL_NSOCKS_PERTHREAD=4
56+
export NCCL_P2P_PXN_LEVEL=0
57+
export NCCL_SOCKET_NTHREADS=1
58+
elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then
59+
echo "Using GPUDirect-TCPFasTrak"
60+
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
61+
export NCCL_ALGO=Ring,Tree
62+
export NCCL_BUFFSIZE=8388608
63+
export NCCL_FASTRAK_CTRL_DEV=eth0
64+
export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
65+
export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
66+
export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
67+
export NCCL_FASTRAK_NUM_FLOWS=2
68+
export NCCL_FASTRAK_USE_LLCM=1
69+
export NCCL_FASTRAK_USE_SNAP=1
70+
export NCCL_MIN_NCHANNELS=4
71+
export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
72+
export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
73+
export NCCL_TUNER_PLUGIN=libnccl-tuner.so
74+
fi
75+
else
76+
echo "NOT using GPUDirect"
77+
fi
78+
}
79+
80+
echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH}"
81+
82+
set_nccl_gpudirect_tcpx_specific_configuration
83+
84+
wait_all_success_or_exit() {
85+
# https://www.baeldung.com/linux/background-process-get-exit-code
86+
local pids=("$@")
87+
while [[ ${#pids[@]} -ne 0 ]]; do
88+
all_success="true"
89+
for pid in "${pids[@]}"; do
90+
code=$(non_blocking_wait "$pid")
91+
if [[ $code -ne 127 ]]; then
92+
if [[ $code -ne 0 ]]; then
93+
echo "PID $pid failed with exit code $code"
94+
exit "$code"
95+
fi
96+
else
97+
all_success="false"
98+
fi
99+
done
100+
if [[ $all_success == "true" ]]; then
101+
echo "All pids succeeded"
102+
break
103+
fi
104+
sleep 5
105+
done
106+
}
107+
non_blocking_wait() {
108+
# https://www.baeldung.com/linux/background-process-get-exit-code
109+
local pid=$1
110+
local code=127 # special code to indicate not-finished
111+
if [[ ! -d "/proc/$pid" ]]; then
112+
wait "$pid"
113+
code=$?
114+
fi
115+
echo $code
116+
}
117+
118+
resolve_coordinator_ip() {
119+
local lookup_attempt=1
120+
local max_coordinator_lookups=500
121+
local coordinator_found=false
122+
local coordinator_ip_address=""
123+
124+
echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"
125+
126+
while [[ "$coordinator_found" = false && $lookup_attempt -le $max_coordinator_lookups ]]; do
127+
coordinator_ip_address=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1)
128+
if [[ -n "$coordinator_ip_address" ]]; then
129+
coordinator_found=true
130+
echo "Coordinator IP address: $coordinator_ip_address"
131+
export JAX_COORDINATOR_IP=$coordinator_ip_address
132+
return 0
133+
else
134+
echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
135+
((lookup_attempt++))
136+
sleep 1
137+
fi
138+
done
139+
140+
if [[ "$coordinator_found" = false ]]; then
141+
echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
142+
return 1
143+
fi
144+
}
145+
146+
# Resolving coordinator IP
147+
set +e
148+
resolve_coordinator_ip
149+
set -e
150+
151+
PIDS=()
152+
eval ${COMMAND} &
153+
PID=$!
154+
PIDS+=($PID)
155+
156+
wait_all_success_or_exit "${PIDS[@]}"

maxdiffusion_dependencies.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ ARG JAX_VERSION
4545
ENV ENV_JAX_VERSION=$JAX_VERSION
4646

4747
# Set the working directory in the container
48-
WORKDIR /app
48+
WORKDIR /deps
4949

5050
# Copy all files from local workspace into docker container
5151
COPY . .
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# syntax=docker/dockerfile:experimental
2+
# Note: This pulls in the lastest of jax:base
3+
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
4+
FROM $BASEIMAGE
5+
6+
# Stopgaps measure to circumvent gpg key setup issue.
7+
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list
8+
9+
# Install dependencies for adjusting network rto
10+
RUN apt-get update && apt-get install -y iproute2 ethtool lsof
11+
12+
# Install DNS util dependencies
13+
RUN apt-get install -y dnsutils
14+
15+
# Add the Google Cloud SDK package repository
16+
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
17+
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
18+
19+
# Install the Google Cloud SDK
20+
RUN apt-get update && apt-get install -y google-cloud-sdk
21+
22+
# Set environment variables for Google Cloud SDK
23+
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
24+
25+
# Upgrade libcusprase to work with Jax
26+
RUN apt-get update && apt-get install -y libcusparse-12-3
27+
28+
ARG MODE
29+
ENV ENV_MODE=$MODE
30+
31+
ARG JAX_VERSION
32+
ENV ENV_JAX_VERSION=$JAX_VERSION
33+
34+
ARG DEVICE
35+
ENV ENV_DEVICE=$DEVICE
36+
37+
RUN mkdir -p /deps
38+
39+
# Set the working directory in the container
40+
WORKDIR /deps
41+
42+
# Copy all files from local workspace into docker container
43+
COPY . .
44+
RUN ls .
45+
46+
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
47+
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
48+
49+
50+
WORKDIR /deps

maxdiffusion_jax_stable_stack_tpu.Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ ARG COMMIT_HASH
77

88
ENV COMMIT_HASH=$COMMIT_HASH
99

10-
RUN mkdir -p /app
10+
RUN mkdir -p /deps
1111

1212
# Set the working directory in the container
13-
WORKDIR /app
13+
WORKDIR /deps
1414

1515
# Copy all files from local workspace into docker container
1616
COPY . .

maxdiffusion_runner.Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ ARG BASEIMAGE=maxdiffusion_base_image
22
FROM $BASEIMAGE
33

44
# Set the working directory in the container
5-
WORKDIR /app
5+
WORKDIR /deps
66

77
# Copy all files from local workspace into docker container
88
COPY . .
99

10-
WORKDIR /app
10+
WORKDIR /deps

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint>=0.5.20
2828
tokenizers==0.20.0
29+
huggingface_hub==0.24.7

0 commit comments

Comments
 (0)