Skip to content

Commit ce3ee64

Browse files
add sharding constraint to reshape after attn. Use mesh with VAE decode.
1 parent 4543686 commit ce3ee64

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def _reshape_heads_to_head_dim(tensor):
9999
# This is used to transform the output of flash attention back into the format of other attention outputs
100100
b, h, s, d = tensor.shape
101101
tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
102-
return jnp.reshape(tensor, (b, -1, h * d))
102+
reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
103+
return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
103104

104105

105106
def _unflatten_heads(tensor, heads):

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,12 +434,13 @@ def __call__(
434434
prompt_embeds=prompt_embeds,
435435
negative_prompt_embeds=negative_prompt_embeds,
436436
)
437-
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
438-
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
439-
latents = latents / latents_std + latents_mean
440-
latents = latents.astype(self.config.weights_dtype)
441-
442-
video = self.vae.decode(latents, self.vae_cache)[0]
437+
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
438+
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
439+
latents = latents / latents_std + latents_mean
440+
latents = latents.astype(self.config.weights_dtype)
441+
442+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
443+
video = self.vae.decode(latents, self.vae_cache)[0]
443444

444445
video = jnp.transpose(video, (0, 4, 1, 2, 3))
445446
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)

0 commit comments

Comments (0)