Skip to content

Commit 026aff7

Browse files
committed
added debug statements
1 parent b9f3258 commit 026aff7

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,9 +646,9 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
646646

647647
for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
648648
if attn is not None:
649-
jax.debug.print("MidBlock before attn {i}: {shape}", shape=x.shape)
649+
jax.debug.print("MidBlock before attn {i}: {shape}", i=i, shape=x.shape)
650650
x = attn(x)
651-
jax.debug.print("MidBlock after attn {i}: {shape}", shape=x.shape)
651+
jax.debug.print("MidBlock after attn {i}: {shape}", i=i, shape=x.shape)
652652
x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
653653
new_cache["resnets"].append(c)
654654

@@ -881,14 +881,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
881881
current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
882882

883883
for i, layer in enumerate(self.down_blocks):
884-
jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", shape=x.shape)
884+
jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", i=i, shape=x.shape)
885885
if isinstance(layer, (WanResidualBlock, WanResample)):
886886
x, c = layer(x, current_down_caches[i])
887887
new_cache["down_blocks"].append(c)
888888
else:
889889
x = layer(x)
890890
new_cache["down_blocks"].append(None)
891-
jax.debug.print(f"Encoder after down_block {i}: {{shape}}", shape=x.shape)
891+
jax.debug.print(f"Encoder after down_block {i}: {{shape}}", i=i, shape=x.shape)
892892

893893
jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
894894
x, c = self.mid_block(x, cache.get("mid_block"))

0 commit comments

Comments
 (0)