Skip to content

Commit 229c7d9

Browse files
committed
working code
1 parent 01ba7c0 commit 229c7d9

8 files changed

Lines changed: 319 additions & 81 deletions

File tree

118 KB
Loading
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
#!/usr/bin/env bash
2+
set -euo pipefail
3+
4+
CONFIG="src/maxdiffusion/configs/base_wan_animate_27b.yml"
5+
VENV_PATH="/data/maxdiffusion-work/maxdiffusion-venv"
6+
OUTPUT_ROOT="/data/maxdiffusion-work/outputs/wan-animate-sweeps"
7+
RUN_PREFIX="wananimate-sweep"
8+
ENABLE_PROFILER="True"
9+
SKIP_JAX_DISTRIBUTED_SYSTEM="True"
10+
NUM_FRAMES_OVERRIDE=""
11+
DRY_RUN=0
12+
declare -a SCENARIOS=()
13+
declare -a EXTRA_OVERRIDES=()
14+
15+
usage() {
16+
cat <<'EOF'
17+
Usage:
18+
scripts/run_wan_animate_parallelism_sweep.sh [options]
19+
20+
Runs a sequence of WAN Animate inference jobs with different parallelism layouts.
21+
Each run gets its own output directory, TensorBoard trace directory, and log file.
22+
23+
Default scenarios:
24+
context8:1:1:8:1
25+
context4_tensor2:1:1:4:2
26+
context2_tensor4:1:1:2:4
27+
context2_fsdp4:1:4:2:1
28+
fsdp4_tensor2:1:4:1:2
29+
cp2_fsdp2_tp2:1:2:2:2
30+
fsdp8:1:8:1:1
31+
32+
Scenario format:
33+
name:data:fsdp:context:tensor
34+
35+
Options:
36+
--config <path> Config file to run
37+
--venv <path> Virtualenv root containing bin/python
38+
--output-root <path> Root folder for the sweep session
39+
--run-prefix <prefix> Prefix used in run names and session directory
40+
--num-frames <int> Override num_frames for all runs
41+
--scenario <spec> Add a scenario; may be repeated
42+
--extra-override <k=v> Extra config override; may be repeated
43+
--no-profiler Disable profiler for all runs
44+
--dry-run Print commands without executing them
45+
-h, --help Show this help
46+
47+
Example:
48+
scripts/run_wan_animate_parallelism_sweep.sh \
49+
--run-prefix wananimate-xprof \
50+
--scenario context8:1:1:8:1 \
51+
--scenario context4_tensor2:1:1:4:2
52+
EOF
53+
}
54+
55+
while [[ $# -gt 0 ]]; do
56+
case "$1" in
57+
--config)
58+
CONFIG="${2:-}"
59+
shift 2
60+
;;
61+
--venv)
62+
VENV_PATH="${2:-}"
63+
shift 2
64+
;;
65+
--output-root)
66+
OUTPUT_ROOT="${2:-}"
67+
shift 2
68+
;;
69+
--run-prefix)
70+
RUN_PREFIX="${2:-}"
71+
shift 2
72+
;;
73+
--num-frames)
74+
NUM_FRAMES_OVERRIDE="${2:-}"
75+
shift 2
76+
;;
77+
--scenario)
78+
SCENARIOS+=("${2:-}")
79+
shift 2
80+
;;
81+
--extra-override)
82+
EXTRA_OVERRIDES+=("${2:-}")
83+
shift 2
84+
;;
85+
--no-profiler)
86+
ENABLE_PROFILER="False"
87+
shift
88+
;;
89+
--dry-run)
90+
DRY_RUN=1
91+
shift
92+
;;
93+
-h|--help)
94+
usage
95+
exit 0
96+
;;
97+
*)
98+
echo "Unknown option: $1" >&2
99+
usage
100+
exit 1
101+
;;
102+
esac
103+
done
104+
105+
if [[ ${#SCENARIOS[@]} -eq 0 ]]; then
106+
SCENARIOS=(
107+
"context8:1:1:8:1"
108+
"context4_tensor2:1:1:4:2"
109+
"context2_tensor4:1:1:2:4"
110+
"context2_fsdp4:1:4:2:1"
111+
"fsdp4_tensor2:1:4:1:2"
112+
"cp2_fsdp2_tp2:1:2:2:2"
113+
"fsdp8:1:8:1:1"
114+
)
115+
fi
116+
117+
if [[ ! -f "${CONFIG}" ]]; then
118+
echo "Config not found: ${CONFIG}" >&2
119+
exit 1
120+
fi
121+
122+
if [[ ! -x "${VENV_PATH}/bin/python" ]]; then
123+
echo "Python not found in venv: ${VENV_PATH}/bin/python" >&2
124+
exit 1
125+
fi
126+
127+
SESSION_ID="$(date -u +%Y%m%d-%H%M%S)"
128+
SESSION_ROOT="${OUTPUT_ROOT%/}/${RUN_PREFIX}-${SESSION_ID}"
129+
LOG_DIR="${SESSION_ROOT}/logs"
130+
mkdir -p "${LOG_DIR}"
131+
132+
SUMMARY_TSV="${SESSION_ROOT}/summary.tsv"
133+
COMMANDS_SH="${SESSION_ROOT}/commands.sh"
134+
135+
{
136+
printf "scenario\trun_name\tstatus\tduration_seconds\toutput_dir\ttensorboard_dir\tlog_file\n"
137+
} > "${SUMMARY_TSV}"
138+
139+
{
140+
echo "#!/usr/bin/env bash"
141+
echo "set -euo pipefail"
142+
echo
143+
echo "# Generated by scripts/run_wan_animate_parallelism_sweep.sh"
144+
echo "# Session root: ${SESSION_ROOT}"
145+
echo
146+
} > "${COMMANDS_SH}"
147+
148+
echo "Sweep session root: ${SESSION_ROOT}"
149+
echo "Summary file: ${SUMMARY_TSV}"
150+
echo "Commands file: ${COMMANDS_SH}"
151+
echo "TensorBoard root: ${SESSION_ROOT}"
152+
echo
153+
154+
for scenario_spec in "${SCENARIOS[@]}"; do
155+
IFS=":" read -r scenario_name data_parallelism fsdp_parallelism context_parallelism tensor_parallelism <<< "${scenario_spec}"
156+
if [[ -z "${scenario_name}" || -z "${data_parallelism}" || -z "${fsdp_parallelism}" || -z "${context_parallelism}" || -z "${tensor_parallelism}" ]]; then
157+
echo "Invalid scenario: ${scenario_spec}" >&2
158+
echo "Expected format: name:data:fsdp:context:tensor" >&2
159+
exit 1
160+
fi
161+
162+
run_name="${RUN_PREFIX}-${scenario_name}"
163+
run_output_dir="${SESSION_ROOT}/artifacts"
164+
run_tensorboard_dir="${run_output_dir}/${run_name}/tensorboard"
165+
log_file="${LOG_DIR}/${run_name}.log"
166+
167+
cmd=(
168+
"${VENV_PATH}/bin/python"
169+
"src/maxdiffusion/generate_wan_animate.py"
170+
"${CONFIG}"
171+
"run_name=${run_name}"
172+
"output_dir=${run_output_dir}"
173+
"enable_profiler=${ENABLE_PROFILER}"
174+
"skip_jax_distributed_system=${SKIP_JAX_DISTRIBUTED_SYSTEM}"
175+
"ici_data_parallelism=${data_parallelism}"
176+
"ici_fsdp_parallelism=${fsdp_parallelism}"
177+
"ici_context_parallelism=${context_parallelism}"
178+
"ici_tensor_parallelism=${tensor_parallelism}"
179+
)
180+
181+
if [[ -n "${NUM_FRAMES_OVERRIDE}" ]]; then
182+
cmd+=("num_frames=${NUM_FRAMES_OVERRIDE}")
183+
fi
184+
185+
for extra_override in "${EXTRA_OVERRIDES[@]}"; do
186+
cmd+=("${extra_override}")
187+
done
188+
189+
printf "%q " "${cmd[@]}" >> "${COMMANDS_SH}"
190+
printf "\n\n" >> "${COMMANDS_SH}"
191+
192+
echo "========================================================================"
193+
echo "Scenario: ${scenario_name}"
194+
echo "Run name: ${run_name}"
195+
echo "Parallelism: data=${data_parallelism}, fsdp=${fsdp_parallelism}, context=${context_parallelism}, tensor=${tensor_parallelism}"
196+
echo "Output dir: ${run_output_dir}/${run_name}"
197+
echo "TensorBoard dir: ${run_tensorboard_dir}"
198+
echo "Log file: ${log_file}"
199+
echo "========================================================================"
200+
201+
if [[ "${DRY_RUN}" -eq 1 ]]; then
202+
printf "DRY RUN: "
203+
printf "%q " "${cmd[@]}"
204+
printf "\n\n"
205+
printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \
206+
"${scenario_name}" \
207+
"${run_name}" \
208+
"DRY_RUN" \
209+
"0" \
210+
"${run_output_dir}/${run_name}" \
211+
"${run_tensorboard_dir}" \
212+
"${log_file}" >> "${SUMMARY_TSV}"
213+
continue
214+
fi
215+
216+
start_ts="$(date +%s)"
217+
set +e
218+
(
219+
echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] Starting ${run_name}"
220+
printf "Command: "
221+
printf "%q " "${cmd[@]}"
222+
printf "\n"
223+
"${cmd[@]}"
224+
) 2>&1 | tee "${log_file}"
225+
exit_code=${PIPESTATUS[0]}
226+
set -e
227+
end_ts="$(date +%s)"
228+
duration="$((end_ts - start_ts))"
229+
230+
if [[ "${exit_code}" -eq 0 ]]; then
231+
status="OK"
232+
else
233+
status="FAIL(${exit_code})"
234+
fi
235+
236+
printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \
237+
"${scenario_name}" \
238+
"${run_name}" \
239+
"${status}" \
240+
"${duration}" \
241+
"${run_output_dir}/${run_name}" \
242+
"${run_tensorboard_dir}" \
243+
"${log_file}" >> "${SUMMARY_TSV}"
244+
245+
echo
246+
done
247+
248+
chmod +x "${COMMANDS_SH}"
249+
250+
echo "Sweep complete."
251+
echo "Summary: ${SUMMARY_TSV}"
252+
echo "Commands: ${COMMANDS_SH}"
253+
echo "To inspect traces tomorrow:"
254+
echo " tensorboard --logdir=${SESSION_ROOT}/artifacts"

src/maxdiffusion/configs/base_wan_animate_27b.yml

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,22 +156,17 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
156156
# conv_in : conv.shape[2] weight
157157
# conv_out : conv.shape[-1] weight
158158
logical_axis_rules: [
159-
['batch', 'data'],
160-
['activation_batch', 'data'],
159+
['batch', ['data', 'fsdp']],
160+
['activation_batch', ['data', 'fsdp']],
161161
['activation_self_attn_heads', ['context', 'tensor']],
162-
['activation_self_attn_q_length', 'context'],
163-
['activation_self_attn_kv_length', None],
164-
['activation_cross_attn_q_length', 'context'],
165-
['activation_cross_attn_heads', 'tensor'],
166-
['activation_cross_attn_kv_length', None],
162+
['activation_cross_attn_q_length', ['context', 'tensor']],
167163
['activation_length', 'context'],
168164
['activation_heads', 'tensor'],
169-
['activation_kv', 'tensor'],
170165
['mlp','tensor'],
171-
['embed', 'fsdp'],
166+
['embed', ['context', 'fsdp']],
172167
['heads', 'tensor'],
173168
['norm', 'tensor'],
174-
['conv_batch', ['data', 'context']],
169+
['conv_batch', ['data', 'context', 'fsdp']],
175170
['out_channels', 'tensor'],
176171
['conv_out', 'context'],
177172
]
@@ -183,12 +178,12 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
183178
# and product of the ICI axes should equal number of devices per slice.
184179
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
185180
dcn_fsdp_parallelism: 1
186-
dcn_context_parallelism: 1
181+
dcn_context_parallelism: -1
187182
dcn_tensor_parallelism: 1
188183
ici_data_parallelism: 1
189-
ici_fsdp_parallelism: 4
190-
ici_context_parallelism: 1
191-
ici_tensor_parallelism: 2
184+
ici_fsdp_parallelism: 1
185+
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
186+
ici_tensor_parallelism: 1
192187

193188
allow_split_physical_axes: False
194189

@@ -294,7 +289,7 @@ negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles,
294289
do_classifier_free_guidance: True
295290
height: 720
296291
width: 1280
297-
num_frames: 81
292+
num_frames: 121
298293
flow_shift: 5.0
299294

300295
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py

src/maxdiffusion/generate_wan_animate.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ def _get_animate_inference_settings(config):
3939
"guidance_scale": getattr(config, "animate_guidance_scale", 1.0),
4040
}
4141

42+
43+
def _frame_summary(name, frames):
44+
"""Return a compact frame-count/size summary for logging."""
45+
if not frames:
46+
return f"{name}_frames=0"
47+
return f"{name}_frames={len(frames)}, {name}_frame_size={getattr(frames[0], 'size', None)}"
48+
49+
4250
def run(config):
4351
writer = max_utils.initialize_summary_writer(config)
4452
if jax.process_index() == 0 and writer:
@@ -53,9 +61,11 @@ def run(config):
5361
reference_image_path = getattr(config, "reference_image_path", "")
5462
if reference_image_path:
5563
image = load_image(reference_image_path)
64+
reference_image_source = reference_image_path
5665
else:
5766
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
5867
image = load_image(image_url)
68+
reference_image_source = image_url
5969

6070
mode = getattr(config, "mode", "animate")
6171
pose_video_path = getattr(config, "pose_video_path", "")
@@ -98,6 +108,28 @@ def run(config):
98108
background_video = load_video(background_video_path)[:num_frames]
99109
mask_video = load_video(mask_video_path)[:num_frames]
100110

111+
max_logging.log(
112+
"Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s"
113+
% (
114+
reference_image_source,
115+
getattr(image, "size", None),
116+
pose_video_path or "<dummy>",
117+
face_video_path or "<dummy>",
118+
_frame_summary("pose", pose_video),
119+
_frame_summary("face", face_video),
120+
)
121+
)
122+
if mode == "replace":
123+
max_logging.log(
124+
"Wan replace inputs: background_video_path=%s, mask_video_path=%s, %s, %s"
125+
% (
126+
background_video_path,
127+
mask_video_path,
128+
_frame_summary("background", background_video),
129+
_frame_summary("mask", mask_video),
130+
)
131+
)
132+
101133
animate_settings = _get_animate_inference_settings(config)
102134
prompt = config.prompt
103135
negative_prompt = config.negative_prompt if animate_settings["guidance_scale"] > 1.0 else None

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import jax
2121
import jax.numpy as jnp
2222
from jax import tree_util
23-
from jax.sharding import NamedSharding, PartitionSpec as P
2423
from flax import nnx
2524
from ...configuration_utils import ConfigMixin
2625
from ..modeling_flax_utils import FlaxModelMixin, get_activation
@@ -100,10 +99,10 @@ def __init__(
10099

101100
self.mesh = mesh
102101
# Set sharding dynamically based on out_channels.
103-
num_vae_spatial_devices = mesh.shape["vae_spatial"]
102+
num_context_axis_devices = mesh.shape["context"]
104103
kernel_sharding = (None, None, None, None, None)
105-
if out_channels % num_vae_spatial_devices == 0:
106-
kernel_sharding = (None, None, None, None, "vae_spatial")
104+
if out_channels % num_context_axis_devices == 0:
105+
kernel_sharding = (None, None, None, None, "conv_out")
107106

108107
self.conv = nnx.Conv(
109108
in_features=in_channels,
@@ -120,8 +119,6 @@ def __init__(
120119
)
121120

122121
def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
123-
# Shard the widest activation dimension across the dedicated VAE mesh.
124-
x = jax.lax.with_sharding_constraint(x, NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)))
125122
current_padding = list(self._causal_padding) # Mutable copy
126123
padding_needed = self._depth_padding_before
127124

0 commit comments

Comments
 (0)