Skip to content

Commit 02670b2

Browse files
committed
Fix chessboard issue
1 parent 92bb868 commit 02670b2

3 files changed

Lines changed: 5 additions & 8 deletions

File tree

src/maxdiffusion/generate_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ class DummyOut: pass
160160
params=upsample_params,
161161
prng_seed=generator,
162162
latents=latents,
163+
height=config.height,
164+
width=config.width,
165+
num_frames=config.num_frames,
163166
latents_normalized=False, # Upsampler operates on normalized latents; VAE decoder handles denorm internally
164167
adain_factor=getattr(config, "upsampler_adain_factor", 0.0),
165168
tone_map_compression_ratio=getattr(config, "upsampler_tone_map_compression_ratio", 0.0),

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -907,8 +907,7 @@ def __call__(
907907
p_t = self.patch_size_t
908908

909909
hidden_states = sample.reshape(B, T // p_t, p_t, H // p, p, W // p, p, C)
910-
# 0:B, 1:T_p, 3:H_p, 5:W_p, 7:C, 2:p_t, 4:p_h, 6:p_w
911-
hidden_states = hidden_states.transpose(0, 1, 3, 5, 7, 2, 4, 6)
910+
hidden_states = hidden_states.transpose(0, 1, 3, 5, 7, 2, 6, 4)
912911
hidden_states = hidden_states.reshape(B, T // p_t, H // p, W // p, -1)
913912

914913
num_blocks = len(self.down_blocks) + 1
@@ -1108,8 +1107,7 @@ def __call__(
11081107
hidden_states = hidden_states.reshape(B, T, H, W, C_out_final, p_t, p, p)
11091108

11101109
# Pair H (2) with p_h (7) and W (3) with p_w (6)
1111-
# 0:B, 1:T, 5:p_t, 2:H, 6:p_h, 3:W, 7:p_w, 4:C_out_final
1112-
hidden_states = hidden_states.transpose(0, 1, 5, 2, 6, 3, 7, 4)
1110+
hidden_states = hidden_states.transpose(0, 1, 5, 2, 7, 3, 6, 4)
11131111
hidden_states = hidden_states.reshape(B, T * p_t, H * p, W * p, C_out_final)
11141112

11151113
return hidden_states

src/maxdiffusion/pipelines/ltx2/pipeline_ltx2_latent_upsample.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,8 @@ def __call__(
200200
latents_std = getattr(self.vae, "latents_std")
201201
latents = self._denormalize_latents(latents, latents_mean, latents_std, scaling_factor)
202202

203-
logging.info(f"[JAX Pipeline] Latents AFTER denorm (upsampler input): shape={latents.shape}, mean={jnp.mean(latents):.4f}, std={jnp.std(latents):.4f}, range=[{jnp.min(latents):.4f}, {jnp.max(latents):.4f}]")
204-
205203
# Run Latent Upsampler model — expects (batch, frames, height, width, channels)
206204
latents_upsampled = self.latent_upsampler.apply({'params': params['latent_upsampler']}, latents)
207-
logging.info(f"[JAX Pipeline] Latents AFTER upsampling: shape={latents_upsampled.shape}, mean={jnp.mean(latents_upsampled):.4f}, std={jnp.std(latents_upsampled):.4f}, range=[{jnp.min(latents_upsampled):.4f}, {jnp.max(latents_upsampled):.4f}]")
208205

209206
if adain_factor > 0.0:
210207
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
@@ -243,7 +240,6 @@ def __call__(
243240
# Cast latents to VAE dtype before decoding (matches main pipeline behavior)
244241
vae_dtype = getattr(self.vae, 'dtype', jnp.float32)
245242
latents = latents.astype(vae_dtype)
246-
logging.info(f"[Upsampler VAE decode] latents shape={latents.shape}, dtype={latents.dtype}")
247243

248244
# Decode latents to video
249245
if timestep is not None:

0 commit comments

Comments (0)