Skip to content

Commit 8f1ffde

Browse files
committed
Add debug printing - will revert
1 parent 5d9d758 commit 8f1ffde

7 files changed

Lines changed: 161 additions & 1 deletion

File tree

sh_scripts/128_intoverflow.sh

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
#!/bin/bash
# Launch a WAN 14B debug training run on a tpu7x-128 xpk cluster.
#
# Image build/push (run once beforehand):
# bash docker_build_dependency_image.sh
# docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
# docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest

CLUSTER_NAME=bodaborg-tpu7x-128
DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256
PROJECT=cloud-tpu-multipod-dev
ZONE=us-central1

# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path.
export RUN_NAME=wan-v7x-128-incre-bring-back-v0

USR_NAME=elisatsai
YOUR_GCS_BUCKET=gs://${USR_NAME}-wan-maxdiffusion

OUTPUT_DIR=${YOUR_GCS_BUCKET}/wan/${RUN_NAME}

# using sanbao's dir
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/

# Assigning to RANDOM seeds bash's PRNG, so seed=$RANDOM below is
# deterministic across runs of this script.
RANDOM=123456789
IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl

# FIX: HF_TOKEN placeholder is now single-quoted. Unquoted, the remote shell
# parses `<your_token_here>` as input/output redirections and the command
# fails with a syntax error at `&&` before training ever starts.
xpk workload create \
  --cluster=$CLUSTER_NAME \
  --project=$PROJECT \
  --zone=$ZONE \
  --device-type=$DEVICE_TYPE \
  --num-slices=1 \
  --command=" \
  pip install . && \
  gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \
  python -m pip install ${LIBTPU_VERSION} && \
  pip install tokamax && \
  export XLA_FLAGS='--xla_dump_to=/tmp/xla_dumps --xla_dump_hlo_as_text --xla_dump_hlo_as_proto' && \
  export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \
  --xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
  --xla_enable_async_all_reduce=true \
  --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
  --xla_max_concurrent_async_all_gathers=4 \
  --xla_tpu_enable_async_all_to_all=true \
  --xla_latency_hiding_scheduler_rerun=5 \
  --xla_tpu_rwb_fusion=false \
  --xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \
  --xla_tpu_impure_enable_packed_bf16_math_ops=false \
  --xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
  --xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
  --xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
  --xla_tpu_enable_all_gather_offload_tracing=true \
  --xla_tpu_use_tc_device_shape_on_sc=true \
  --xla_tpu_prefer_async_allgather_to_allreduce=true \
  --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
  --xla_tpu_scoped_vmem_limit_kib=65536 \
  --xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \
  --xla_enable_transpose_trace=false' && \
  export HF_TOKEN='<your_token_here>' && \
  echo 'Starting WAN training ...' && \
  HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=5.0 \
  fps=16 \
  skip_jax_distributed_system=False \
  run_name='test-wan-training-new' \
  output_dir=${OUTPUT_DIR} \
  train_data_dir=${DATASET_DIR} \
  load_tfrecord_cached=True \
  height=1280 \
  width=720 \
  num_frames=81 \
  num_inference_steps=50 \
  prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
  jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
  enable_profiler=True \
  dataset_save_location=${SAVE_DATASET_DIR} \
  remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
  flash_min_seq_length=0 \
  seed=$RANDOM \
  skip_first_n_steps_for_profiler=0 \
  profiler_steps=2 \
  per_device_batch_size=0.5 \
  ici_data_parallelism=64 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=1 \
  allow_split_physical_axes=True \
  max_train_steps=2 \
  scan_layers=true \
  flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' || (echo 'Training failed, uploading HLO dumps...'; sleep 5; gsutil -m cp -r /tmp/xla_dumps ${OUTPUT_DIR}/hlo_dumps/ 2>&1); \
  " \
  --base-docker-image=${IMAGE_DIR} \
  --enable-debug-logs \
  --workload=${RUN_NAME} \
  --priority=medium \
  --max-restarts=0

sh_scripts/filter_by_job.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/bash
# List xpk workloads on the cluster, filtered to the job named by $RUN_NAME.

# ANSI color codes for highlighting the echoed command.
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# Echo the exact command (in green) before running it, so the log shows
# precisely what was executed.
echo -e "${GREEN}xpk workload list --cluster=bodaborg-tpu7x-128 --project=cloud-tpu-multipod-dev --zone=us-central1 --filter-by-job=$RUN_NAME${NC}"
xpk workload list --cluster=bodaborg-tpu7x-128 --project=cloud-tpu-multipod-dev --zone=us-central1 --filter-by-job=$RUN_NAME

sh_scripts/first_pod_log.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Tail the logs of the first pod whose name matches $RUN_NAME.

# ANSI color codes for highlighting the echoed command.
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# Echo the exact command (in green) before running it; the inner pipeline
# picks the first matching pod name from `kubectl get pods`.
echo -e "${GREEN}kubectl logs $(kubectl get pods | grep $RUN_NAME | head -1 | awk '{print $1}') --all-containers --tail=200${NC}"
kubectl logs $(kubectl get pods | grep $RUN_NAME | head -1 | awk '{print $1}') --all-containers --tail=200

sh_scripts/get_pods.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Show the pods belonging to the workload named by $RUN_NAME.

# ANSI color codes for highlighting the echoed command.
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# Echo the exact command (in green) before running it.
echo -e "${GREEN}kubectl get pods | grep $RUN_NAME${NC}"
kubectl get pods | grep $RUN_NAME

sh_scripts/linter_test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Lint helper; for the full suite run: bash unit_test_and_lint.sh
# Auto-fixes whatever ruff can fix in the current tree.
ruff check . --fix

src/maxdiffusion/max_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,4 +650,30 @@ def maybe_initialize_jax_distributed_system(raw_keys):
650650
initialize_jax_for_gpu()
651651
max_logging.log("Jax distributed system initialized on GPU!")
652652
else:
653-
jax.distributed.initialize()
653+
jax.distributed.initialize()
654+
655+
656+
def get_tensor_sharding_info(tensor, name="tensor", loc=""):
  """Print tensor sharding info using jax.debug.inspect_array_sharding.

  Uses jax.debug.inspect_array_sharding, which reports sharding metadata
  without transferring the tensor data to host (avoiding OOM on large
  activations).

  Args:
    tensor: JAX array (or tracer inside jit) to inspect.
    name: Human-readable name for the tensor, included in the printed output.
    loc: Location tag (e.g., "MLP_INPUT", "FFN_OUTPUT") identifying where the
      sharding is checked.

  Returns:
    The tensor unchanged (for use in chaining).
  """
  # Fix: the original computed an unused ``loc_str`` and ignored ``name``
  # entirely even though it is documented as "for logging"; both are now
  # folded into a single printed prefix.
  prefix = f"[{loc}] {name}" if loc else name

  def _print_with_prefix(msg):
    # The callback only ever receives sharding metadata, never array data,
    # so printing here cannot OOM on large tensors.
    print(f"{prefix}: {msg}")

  jax.debug.inspect_array_sharding(tensor, callback=_print_with_prefix)
  return tensor

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ...normalization_flax import FP32LayerNorm
3737
from ...attention_flax import FlaxWanAttention
3838
from ...gradient_checkpoint import GradientCheckpointType
39+
from ....max_utils import get_tensor_sharding_info
3940

4041
BlockSizes = common_types.BlockSizes
4142

@@ -187,6 +188,7 @@ def __init__(
187188

188189
def __call__(self, x: jax.Array) -> jax.Array:
    """Project the input, log the projection's sharding, then apply GELU."""
    projected = self.proj(x)
    # Debug-only: emit sharding metadata for the pre-GELU activation.
    get_tensor_sharding_info(projected, "ffn_activation_before_gelu", loc="GELU_PROJ_OUTPUT")
    return nnx.gelu(projected)
191193

192194

@@ -245,7 +247,9 @@ def conditional_named_scope(self, name: str):
245247

246248
def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
247249
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
250+
get_tensor_sharding_info(hidden_states, "mlp_input", loc="MLP_INPUT")
248251
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
252+
get_tensor_sharding_info(hidden_states, "mlp_intermediate", loc="MLP_INTERMEDIATE")
249253
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
250254
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
251255
with self.conditional_named_scope("mlp_down_proj"):
@@ -359,6 +363,7 @@ def __call__(
359363
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
360364
)
361365
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
366+
get_tensor_sharding_info(hidden_states, "hidden_states_after_constraint", loc="BLOCK_ENTRY")
362367
hidden_states = checkpoint_name(hidden_states, "hidden_states")
363368
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
364369

0 commit comments

Comments
 (0)