
Commit e154cf2

added debug statements
1 parent de0cbbb commit e154cf2

1 file changed: 62 additions & 18 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -99,7 +99,7 @@ def __init__(
 
   def initialize_cache(self, batch_size, height, width, dtype):
     cache = jnp.zeros(
-        (batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype
+        (batch_size, CACHE_T, height, width, self.conv.in_features), dtype=jnp.bfloat16
     )
 
     # OPTIMIZATION: Spatial Partitioning on Initialization
@@ -139,6 +139,8 @@ def __call__(
     current_padding = list(self._causal_padding)
 
     if cache_x is not None:
+      if cache_x.dtype != x.dtype:
+        cache_x = cache_x.astype(x.dtype)
       x_concat = jnp.concatenate([cache_x, x], axis=1)
       new_cache = x_concat[:, -CACHE_T:, ...]
 
@@ -162,6 +164,8 @@ def __call__(
       x_padded = x_input
 
     out = self.conv(x_padded)
+    new_cache = new_cache.astype(jnp.bfloat16)
+    print(f"Exiting WanCausalConv3d: {out.shape}")
     return out, new_cache
 
 
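
Taken together, the WanCausalConv3d changes above amount to a "store the rolling cache in bfloat16, cast it back to the activation dtype on read" pattern: the cache is allocated and written as bfloat16, and upcast only when it is concatenated with the incoming activations. A minimal, self-contained sketch of that pattern (illustrative only; cached_concat and cache_t are not names from this file):

import jax.numpy as jnp

def cached_concat(x, cache_x, cache_t=2):
  # Sketch of the cast-on-read / cast-on-write cache handling in the diff above.
  if cache_x is not None:
    if cache_x.dtype != x.dtype:
      # Cast the bf16 cache up to the compute dtype before concatenating.
      cache_x = cache_x.astype(x.dtype)
    x = jnp.concatenate([cache_x, x], axis=1)
  # Keep only the last `cache_t` frames and store them in bfloat16,
  # matching the initialize_cache change above.
  new_cache = x[:, -cache_t:, ...].astype(jnp.bfloat16)
  return x, new_cache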
@@ -189,11 +193,13 @@ def __init__(
       self.bias = 0
 
   def __call__(self, x: jax.Array) -> jax.Array:
+    print(f"Entering WanRMS_norm: {x.shape}")
     normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
     normalized = x / jnp.maximum(normalized, self.eps)
     normalized = normalized * self.scale * self.gamma
     if self.bias:
       return normalized + self.bias.value
+    print(f"Exiting WanRMS_norm: {normalized.shape}")
     return normalized
 
 
@@ -206,18 +212,23 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     self.method = method
 
   def __call__(self, x: jax.Array) -> jax.Array:
+    print(f"Entering WanUpsample: {x.shape}")
     input_dtype = x.dtype
     in_shape = x.shape
     assert len(in_shape) == 4, "This module only takes tensors with shape of 4."
     n, h, w, c = in_shape
     target_h = int(h * self.scale_factor[0])
     target_w = int(w * self.scale_factor[1])
     out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
-    return out.astype(input_dtype)
+    out = out.astype(input_dtype)
+    print(f"Exiting WanUpsample: {out.shape}")
+    return out
 
 
 class Identity(nnx.Module):
   def __call__(self, x, cache=None):
+    print(f"Entering Identity: {x.shape}")
+    print(f"Exiting Identity: {x.shape}")
     return x, cache
 
 
@@ -252,7 +263,10 @@ def __init__(
     )
 
   def __call__(self, x, cache=None):
-    return self.conv(x), cache
+    print(f"Entering ZeroPaddedConv2D: {x.shape}")
+    out = self.conv(x)
+    print(f"Exiting ZeroPaddedConv2D: {out.shape}")
+    return out, cache
 
 
 class WanResample(nnx.Module):
@@ -363,46 +377,59 @@ def initialize_cache(self, batch_size, height, width, dtype):
   def __call__(
       self, x: jax.Array, cache: Dict[str, Any] = None
   ) -> Tuple[jax.Array, Dict[str, Any]]:
+    print(f"Entering WanResample ({self.mode}): {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
 
     if self.mode == "upsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
+      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x = self.resample(x)
+      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "upsample3d":
       x, tc_cache = self.time_conv(x, cache.get("time_conv"))
       new_cache["time_conv"] = tc_cache
+      print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
 
       b, t, h, w, c = x.shape
       x = x.reshape(b, t, h, w, 2, c // 2)
       x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
       x = x.reshape(b, t * 2, h, w, c // 2)
+      print(f"WanResample ({self.mode}) after time dim expand: {x.shape}")
+
 
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
+      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x = self.resample(x)
+      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "downsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
+      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x, _ = self.resample(x, None)
+      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "downsample3d":
       x, tc_cache = self.time_conv(x, cache.get("time_conv"))
       new_cache["time_conv"] = tc_cache
+      print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
 
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
+      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x, _ = self.resample(x, None)
+      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
@@ -413,6 +440,7 @@ def __call__(
     else:
       x = self.resample(x)
 
+    print(f"Exiting WanResample ({self.mode}): {x.shape}")
     return x, new_cache
 
 
@@ -486,26 +514,33 @@ def initialize_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+    print(f"Entering WanResidualBlock (in={self.conv1.conv.in_features}, out={self.conv1.conv.out_features}): {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
 
     h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
     new_cache["shortcut"] = sc_cache
+    print(f"WanResidualBlock after shortcut: {h.shape}")
 
     x = self.norm1(x)
     x = self.nonlinearity(x)
+    print(f"WanResidualBlock after norm1/nl: {x.shape}")
 
     x, c1 = self.conv1(x, cache.get("conv1"))
     new_cache["conv1"] = c1
+    print(f"WanResidualBlock after conv1: {x.shape}")
 
     x = self.norm2(x)
     x = self.nonlinearity(x)
+    print(f"WanResidualBlock after norm2/nl: {x.shape}")
 
     x, c2 = self.conv2(x, cache.get("conv2"))
     new_cache["conv2"] = c2
+    print(f"WanResidualBlock after conv2: {x.shape}")
 
     x = x + h
+    print(f"Exiting WanResidualBlock: {x.shape}")
     return x, new_cache
 
 
@@ -544,15 +579,17 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array):
-    jax.debug.print("AttentionBlock input shape: {shape}", shape=x.shape)
+    print(f"Entering WanAttentionBlock: {x.shape}")
     identity = x
     batch_size, time, height, width, channels = x.shape
 
     x = x.reshape(batch_size * time, height, width, channels)
+    print(f"WanAttentionBlock reshaped for norm: {x.shape}")
     x = self.norm(x)
 
     qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
     # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+    print(f"WanAttentionBlock qkv shape: {qkv.shape}")
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
     q, k, v = jnp.split(qkv, 3, axis=-2)
@@ -566,8 +603,9 @@ def __call__(self, x: jax.Array):
     x = self.proj(x)
     # Reshape back
     x = x.reshape(batch_size, time, height, width, channels)
-
-    return x + identity
+    out = x + identity
+    print(f"Exiting WanAttentionBlock: {out.shape}")
+    return out
 
 
 class WanMidBlock(nnx.Module):
@@ -635,23 +673,26 @@ def initialize_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+    print(f"Entering WanMidBlock: {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {"resnets": []}
-    jax.debug.print("MidBlock input shape: {shape}", shape=x.shape)
 
     x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
     new_cache["resnets"].append(c)
-    jax.debug.print("MidBlock after resnets[0] shape: {shape}", shape=x.shape)
+    print(f"WanMidBlock after resnets[0]: {x.shape}")
 
     for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
       if attn is not None:
-        jax.debug.print("MidBlock before attn {i}: {shape}", i=i, shape=x.shape)
+        print(f"WanMidBlock before attn {i}: {x.shape}")
         x = attn(x)
-        jax.debug.print("MidBlock after attn {i}: {shape}", i=i, shape=x.shape)
+        print(f"WanMidBlock after attn {i}: {x.shape}")
+      print(f"WanMidBlock before resnets[{i + 1}]: {x.shape}")
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
       new_cache["resnets"].append(c)
+      print(f"WanMidBlock after resnets[{i + 1}]: {x.shape}")
 
+    print(f"Exiting WanMidBlock: {x.shape}")
     return x, new_cache
 
 
@@ -869,38 +910,41 @@ def init_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+    print(f"Entering WanEncoder3d: {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
-    jax.debug.print("Encoder input shape: {shape}", shape=x.shape)
+
     x, c = self.conv_in(x, cache.get("conv_in"))
     new_cache["conv_in"] = c
-    jax.debug.print("Encoder after conv_in shape: {shape}", shape=x.shape)
+    print(f"WanEncoder3d after conv_in: {x.shape}")
 
     new_cache["down_blocks"] = []
     current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
     for i, layer in enumerate(self.down_blocks):
-      jax.debug.print("Encoder before down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
+      print(f"WanEncoder3d before down_block {i} ({type(layer).__name__}): {x.shape}")
       if isinstance(layer, (WanResidualBlock, WanResample)):
         x, c = layer(x, current_down_caches[i])
         new_cache["down_blocks"].append(c)
       else:
         x = layer(x)
         new_cache["down_blocks"].append(None)
-      jax.debug.print("Encoder after down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
+      print(f"WanEncoder3d after down_block {i}: {x.shape}")
 
 
-    jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
     x, c = self.mid_block(x, cache.get("mid_block"))
     new_cache["mid_block"] = c
-    jax.debug.print("Encoder after mid_block: {shape}", shape=x.shape)
+    print(f"WanEncoder3d after mid_block: {x.shape}")
 
     x = self.norm_out(x)
+    print(f"WanEncoder3d after norm_out: {x.shape}")
     x = self.nonlinearity(x)
+    print(f"WanEncoder3d after nonlinearity: {x.shape}")
 
     x, c = self.conv_out(x, cache.get("conv_out"))
     new_cache["conv_out"] = c
+    print(f"Exiting WanEncoder3d: {x.shape}")
 
     return x, new_cache
 
@@ -1203,7 +1247,7 @@ def __init__(
         precision=precision,
     )
 
-  @nnx.jit
+  # @nnx.jit
   def encode(
       self, x: jax.Array, return_dict: bool = True
   ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
@@ -1270,7 +1314,7 @@ def scan_fn_chunk(carry, input_slice):
       return (posterior,)
     return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
-  @nnx.jit
+  # @nnx.jit
   def decode(
       self, z: jax.Array, return_dict: bool = True
   ) -> Union[FlaxDecoderOutput, jax.Array]:
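
For context on the decorator and logging changes in this commit: with @nnx.jit commented out, encode and decode run eagerly, so plain Python print fires on every call. Under jit, a Python print runs only while the function is traced (shapes are still concrete tuples at that point), whereas jax.debug.print is staged into the compiled computation and prints runtime values on every execution. A minimal sketch of that difference, independent of this file:

import jax
import jax.numpy as jnp

def f(x):
  # Runs at trace time only under jit; x.shape is a concrete tuple even then.
  print(f"tracing f with shape {x.shape}")
  # Staged into the jitted computation; prints the runtime value on every call.
  jax.debug.print("runtime mean = {}", jnp.mean(x))
  return x * 2

jitted = jax.jit(f)
x = jnp.ones((2, 3))
jitted(x)  # first call: both messages appear (tracing happens here)
jitted(x)  # second call: only the jax.debug.print message appears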
