Skip to content

Commit b9f3258

Browse files
committed
added debug statements
1 parent 4cac799 commit b9f3258

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def __init__(
544544
)
545545

546546
def __call__(self, x: jax.Array):
547-
547+
jax.debug.print("AttentionBlock input shape: {shape}", shape=x.shape)
548548
identity = x
549549
batch_size, time, height, width, channels = x.shape
550550

@@ -638,13 +638,17 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
638638
if cache is None:
639639
cache = {}
640640
new_cache = {"resnets": []}
641+
jax.debug.print("MidBlock input shape: {shape}", shape=x.shape)
641642

642643
x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
643644
new_cache["resnets"].append(c)
645+
jax.debug.print("MidBlock after resnets[0] shape: {shape}", shape=x.shape)
644646

645647
for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
646648
if attn is not None:
649+
jax.debug.print("MidBlock before attn {i}: {shape}", shape=x.shape)
647650
x = attn(x)
651+
jax.debug.print("MidBlock after attn {i}: {shape}", shape=x.shape)
648652
x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
649653
new_cache["resnets"].append(c)
650654

@@ -868,23 +872,28 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
868872
if cache is None:
869873
cache = {}
870874
new_cache = {}
871-
875+
jax.debug.print("Encoder input shape: {shape}", shape=x.shape)
872876
x, c = self.conv_in(x, cache.get("conv_in"))
873877
new_cache["conv_in"] = c
878+
jax.debug.print("Encoder after conv_in shape: {shape}", shape=x.shape)
874879

875880
new_cache["down_blocks"] = []
876881
current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
877882

878883
for i, layer in enumerate(self.down_blocks):
884+
jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", shape=x.shape)
879885
if isinstance(layer, (WanResidualBlock, WanResample)):
880886
x, c = layer(x, current_down_caches[i])
881887
new_cache["down_blocks"].append(c)
882888
else:
883889
x = layer(x)
884890
new_cache["down_blocks"].append(None)
891+
jax.debug.print(f"Encoder after down_block {i}: {{shape}}", shape=x.shape)
885892

893+
jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
886894
x, c = self.mid_block(x, cache.get("mid_block"))
887895
new_cache["mid_block"] = c
896+
jax.debug.print("Encoder after mid_block: {shape}", shape=x.shape)
888897

889898
x = self.norm_out(x)
890899
x = self.nonlinearity(x)
@@ -1193,7 +1202,7 @@ def __init__(
11931202
precision=precision,
11941203
)
11951204

1196-
# @nnx.jit
1205+
@nnx.jit
11971206
def encode(
11981207
self, x: jax.Array, return_dict: bool = True
11991208
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:

0 commit comments

Comments
 (0)