Skip to content

Commit ef956e2

Browse files
committed
Replace hardcoded PartitionSpec in VACE with logical axis mapping
1 parent 4be2899 commit ef956e2

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import jax
2222
from jax.ad_checkpoint import checkpoint_name
2323
import jax.numpy as jnp
24-
from jax.sharding import PartitionSpec
24+
2525

2626
from .... import common_types
2727
from ....configuration_utils import register_to_config
@@ -201,12 +201,12 @@ def __call__(
201201

202202
control_hidden_states = jax.lax.with_sharding_constraint(
203203
control_hidden_states,
204-
PartitionSpec("data", "fsdp", "tensor"),
204+
nn.logical_to_mesh_axes(("activation_batch", "activation_length", None)),
205205
)
206206
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
207207
encoder_hidden_states = jax.lax.with_sharding_constraint(
208208
encoder_hidden_states,
209-
PartitionSpec("data", "fsdp", None),
209+
nn.logical_to_mesh_axes(("activation_batch", "activation_length", None)),
210210
)
211211

212212
# 1. Self-attention

0 commit comments

Comments
 (0)