Skip to content

Commit 85c8433

Browse files
committed
Add debug printing - will revert
1 parent 5d9d758 commit 85c8433

7 files changed

Lines changed: 153 additions & 1 deletion

File tree

sh_scripts/128_intoverflow.sh

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
#!/bin/bash
# Launch a single-slice WAN 14B training run on a tpu7x-128 xpk cluster with
# XLA HLO dumping enabled for debugging; dumps are uploaded to GCS if the
# training command fails.
#
# One-time image setup:
# bash docker_build_dependency_image.sh
# docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
# docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest

CLUSTER_NAME=bodaborg-tpu7x-128
DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256
PROJECT=cloud-tpu-multipod-dev
ZONE=us-central1

# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path.
export RUN_NAME=wan-v7x-128-incre-bring-back-v0

USR_NAME=elisatsai
YOUR_GCS_BUCKET=gs://${USR_NAME}-wan-maxdiffusion

OUTPUT_DIR=${YOUR_GCS_BUCKET}/wan/${RUN_NAME}

# using sanbao's dir
DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/

# Assigning to RANDOM seeds bash's PRNG, so `seed=$RANDOM` below yields a
# reproducible value across runs of this script.
RANDOM=123456789
IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl

# NOTE(review): HF_TOKEN inside --command is the literal placeholder
# `<your_token_here>`; the workload will fail at model download until a real
# token is substituted.
# NOTE(review): run_name='test-wan-training-new' differs from the xpk
# --workload=${RUN_NAME} name — confirm this mismatch is intentional.
xpk workload create \
  --cluster=$CLUSTER_NAME \
  --project=$PROJECT \
  --zone=$ZONE \
  --device-type=$DEVICE_TYPE \
  --num-slices=1 \
  --command=" \
    pip install . && \
    gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \
    python -m pip install ${LIBTPU_VERSION} && \
    pip install tokamax && \
    export XLA_FLAGS='--xla_dump_to=/tmp/xla_dumps --xla_dump_hlo_as_text --xla_dump_hlo_as_proto' && \
    export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \
      --xla_tpu_enable_async_collective_fusion=true \
      --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
      --xla_enable_async_all_reduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
      --xla_max_concurrent_async_all_gathers=4 \
      --xla_tpu_enable_async_all_to_all=true \
      --xla_latency_hiding_scheduler_rerun=5 \
      --xla_tpu_rwb_fusion=false \
      --xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \
      --xla_tpu_impure_enable_packed_bf16_math_ops=false \
      --xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
      --xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
      --xla_tpu_enable_all_gather_offload_tracing=true \
      --xla_tpu_use_tc_device_shape_on_sc=true \
      --xla_tpu_prefer_async_allgather_to_allreduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
      --xla_tpu_scoped_vmem_limit_kib=65536 \
      --xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \
      --xla_enable_transpose_trace=false' && \
    export HF_TOKEN=<your_token_here> && \
    echo 'Starting WAN training ...' && \
    HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \
      src/maxdiffusion/configs/base_wan_14b.yml \
      attention='flash' \
      weights_dtype=bfloat16 \
      activations_dtype=bfloat16 \
      guidance_scale=5.0 \
      flow_shift=5.0 \
      fps=16 \
      skip_jax_distributed_system=False \
      run_name='test-wan-training-new' \
      output_dir=${OUTPUT_DIR} \
      train_data_dir=${DATASET_DIR} \
      load_tfrecord_cached=True \
      height=1280 \
      width=720 \
      num_frames=81 \
      num_inference_steps=50 \
      prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
      jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
      enable_profiler=True \
      dataset_save_location=${SAVE_DATASET_DIR} \
      remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
      flash_min_seq_length=0 \
      seed=$RANDOM \
      skip_first_n_steps_for_profiler=0 \
      profiler_steps=2 \
      per_device_batch_size=0.5 \
      ici_data_parallelism=64 \
      ici_fsdp_parallelism=2 \
      ici_tensor_parallelism=1 \
      allow_split_physical_axes=True \
      max_train_steps=2 \
      scan_layers=true \
      flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' || (echo 'Training failed, uploading HLO dumps...'; sleep 5; gsutil -m cp -r /tmp/xla_dumps ${OUTPUT_DIR}/hlo_dumps/ 2>&1); \
    " \
  --base-docker-image=${IMAGE_DIR} \
  --enable-debug-logs \
  --workload=${RUN_NAME} \
  --priority=medium \
  --max-restarts=0

sh_scripts/filter_by_job.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
#!/bin/bash
# Print (in green) and then run `xpk workload list` filtered to the workload
# named by $RUN_NAME. Requires RUN_NAME to be exported by the caller.

# Color codes
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# Build the command once so the echoed text can never drift from what runs.
CMD="xpk workload list --cluster=bodaborg-tpu7x-128 --project=cloud-tpu-multipod-dev --zone=us-central1 --filter-by-job=$RUN_NAME"
echo -e "${GREEN}${CMD}${NC}"
$CMD

sh_scripts/first_pod_log.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
#!/bin/bash
# Tail the last 200 log lines (all containers) of the first pod whose name
# matches $RUN_NAME. Requires RUN_NAME to be exported by the caller.

# Color codes
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

# Resolve the pod name once: the original echoed command re-ran
# `kubectl get pods` inside $(...), so the printed name could differ from the
# one actually queried a moment later.
POD=$(kubectl get pods | grep $RUN_NAME | head -1 | awk '{print $1}')
echo -e "${GREEN}kubectl logs ${POD} --all-containers --tail=200${NC}"
kubectl logs "$POD" --all-containers --tail=200

sh_scripts/get_pods.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
#!/bin/bash
# List all Kubernetes pods whose name contains $RUN_NAME (must be exported by
# the caller); the command itself is echoed first in green for visibility.

# Color codes
BLUE='\033[0;34m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

echo -e "${GREEN}kubectl get pods | grep $RUN_NAME${NC}"
kubectl get pods | grep $RUN_NAME

sh_scripts/linter_test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
# Lint-only helper: run ruff with auto-fix over the repo.
# For the full test + lint suite use:
# bash unit_test_and_lint.sh
ruff check . --fix

src/maxdiffusion/max_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,4 +650,23 @@ def maybe_initialize_jax_distributed_system(raw_keys):
650650
initialize_jax_for_gpu()
651651
max_logging.log("Jax distributed system initialized on GPU!")
652652
else:
653-
jax.distributed.initialize()
653+
jax.distributed.initialize()
654+
655+
656+
def get_tensor_sharding_info(tensor, name="tensor"):
657+
"""Print tensor sharding info at runtime using jax.debug.print.
658+
659+
Args:
660+
tensor: JAX array to inspect
661+
name: Human-readable name for the tensor (for logging)
662+
663+
Returns:
664+
The sharding of the tensor, or None if error occurs
665+
"""
666+
try:
667+
sharding = tensor.sharding
668+
jax.debug.print(f"[SHARDING] {name}: shape={tensor.shape}, sharding={sharding}")
669+
return sharding
670+
except Exception as e:
671+
jax.debug.print(f"[SHARDING] {name}: ERROR getting sharding info - {e}")
672+
return None

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ...normalization_flax import FP32LayerNorm
3737
from ...attention_flax import FlaxWanAttention
3838
from ...gradient_checkpoint import GradientCheckpointType
39+
from ....max_utils import get_tensor_sharding_info
3940

4041
BlockSizes = common_types.BlockSizes
4142

@@ -187,6 +188,7 @@ def __init__(
187188

188189
def __call__(self, x: jax.Array) -> jax.Array:
189190
x = self.proj(x)
191+
get_tensor_sharding_info(x, "ffn_activation_before_gelu")
190192
return nnx.gelu(x)
191193

192194

@@ -246,6 +248,7 @@ def conditional_named_scope(self, name: str):
246248
def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
247249
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
248250
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
251+
get_tensor_sharding_info(hidden_states, "ffn_output_after_gelu")
249252
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
250253
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
251254
with self.conditional_named_scope("mlp_down_proj"):
@@ -359,6 +362,7 @@ def __call__(
359362
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
360363
)
361364
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
365+
get_tensor_sharding_info(hidden_states, "hidden_states_after_block_entry_constraint")
362366
hidden_states = checkpoint_name(hidden_states, "hidden_states")
363367
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
364368

0 commit comments

Comments
 (0)