Skip to content

Commit 2388908

Browse files
use collapse instead of reshape for final activation.
1 parent b7c8ba6 commit 2388908

2 files changed

Lines changed: 35 additions & 25 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,14 @@ def __call__(
447447
return_dict: bool = True,
448448
attention_kwargs: Optional[Dict[str, Any]] = None,
449449
) -> Union[jax.Array, Dict[str, jax.Array]]:
450-
batch_size, num_frames, height, width, num_channels = hidden_states.shape
450+
batch_size, _, num_frames, height, width = hidden_states.shape
451451
p_t, p_h, p_w = self.config.patch_size
452452
post_patch_num_frames = num_frames // p_t
453453
post_patch_height = height // p_h
454454
post_patch_width = width // p_w
455455

456456

457+
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
457458
rotary_emb = self.rope(hidden_states)
458459
hidden_states = self.patch_embedding(hidden_states)
459460
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
@@ -472,9 +473,9 @@ def __call__(
472473
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
473474
hidden_states = self.proj_out(hidden_states)
474475

475-
# TODO - can this reshape happen in a single command?
476476
hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1)
477-
hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, num_channels)
478-
# jax.debug.print("FINAL hidden_states min: {x}", x=hidden_states.min())
479-
# jax.debug.print("FINAL hidden_states max: {x}", x=hidden_states.max())
477+
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
478+
hidden_states = jax.lax.collapse(hidden_states, 6, None)
479+
hidden_states = jax.lax.collapse(hidden_states, 4, 6)
480+
hidden_states = jax.lax.collapse(hidden_states, 2, 4)
480481
return hidden_states

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ def __init__(
109109
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
110110
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
111111

112+
self.jitted_decode = jax.jit(
113+
partial(
114+
self.vae.decode,
115+
feat_cache=self.vae_cache,
116+
return_dict=False
117+
)
118+
)
119+
120+
self.p_run_inference = None
121+
112122
@classmethod
113123
def load_text_encoder(cls, config: HyperParameters):
114124
text_encoder = UMT5EncoderModel.from_pretrained(
@@ -184,20 +194,27 @@ def load_scheduler(cls, config):
184194
return scheduler, scheduler_state
185195

186196
@classmethod
187-
def from_pretrained(cls, config: HyperParameters):
197+
def from_pretrained(cls, config: HyperParameters, vae_only=False):
188198
devices_array = max_utils.create_device_mesh(config)
189199
mesh = Mesh(devices_array, config.mesh_axes)
190200
rng = jax.random.key(config.seed)
191201
rngs = nnx.Rngs(rng)
202+
transformer=None
203+
tokenizer=None
204+
scheduler=None
205+
scheduler_state=None
206+
text_encoder=None
207+
if not vae_only:
208+
with mesh:
209+
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
210+
211+
text_encoder = cls.load_text_encoder(config=config)
212+
tokenizer = cls.load_tokenizer(config=config)
213+
214+
scheduler, scheduler_state = cls.load_scheduler(config=config)
192215

193216
with mesh:
194217
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
195-
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
196-
197-
text_encoder = cls.load_text_encoder(config=config)
198-
tokenizer = cls.load_tokenizer(config=config)
199-
200-
scheduler, scheduler_state = cls.load_scheduler(config=config)
201218

202219
return WanPipeline(
203220
tokenizer=tokenizer,
@@ -291,10 +308,10 @@ def prepare_latents(
291308
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
292309
shape = (
293310
batch_size,
311+
num_channels_latents,
294312
num_latent_frames,
295313
int(height) // vae_scale_factor_spatial,
296314
int(width) // vae_scale_factor_spatial,
297-
num_channels_latents
298315
)
299316
latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype)
300317

@@ -370,6 +387,7 @@ def __call__(
370387
scheduler=self.scheduler,
371388
scheduler_state=scheduler_state
372389
)
390+
373391
with self.mesh:
374392
latents = p_run_inference(
375393
graphdef=graphdef,
@@ -379,21 +397,13 @@ def __call__(
379397
prompt_embeds=prompt_embeds,
380398
negative_prompt_embeds=negative_prompt_embeds
381399
)
382-
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
383-
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
400+
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
401+
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
384402
latents = latents / latents_std + latents_mean
385-
386403
latents = latents.astype(self.config.weights_dtype)
387404

388-
jitted_decode = jax.jit(
389-
partial(
390-
self.vae.decode,
391-
feat_cache=self.vae_cache,
392-
return_dict=False
393-
)
394-
)
395405
with self.mesh:
396-
video = jitted_decode(latents)[0]
406+
video = self.jitted_decode(latents)[0]
397407
video = jnp.transpose(video, (0, 4, 1, 2, 3))
398408
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
399409
video = self.video_processor.postprocess_video(video, output_type="np")
@@ -431,6 +441,5 @@ def run_inference(
431441
if do_classifier_free_guidance:
432442
noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds)
433443
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
434-
435444
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
436445
return latents

0 commit comments

Comments (0)