Skip to content

Commit de0cbbb

Browse files
committed
added debug statements
1 parent 026aff7 commit de0cbbb

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,14 +881,15 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
881881
current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
882882

883883
for i, layer in enumerate(self.down_blocks):
884-
jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", i=i, shape=x.shape)
884+
jax.debug.print("Encoder before down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
885885
if isinstance(layer, (WanResidualBlock, WanResample)):
886886
x, c = layer(x, current_down_caches[i])
887887
new_cache["down_blocks"].append(c)
888888
else:
889889
x = layer(x)
890890
new_cache["down_blocks"].append(None)
891-
jax.debug.print(f"Encoder after down_block {i}: {{shape}}", i=i, shape=x.shape)
891+
jax.debug.print("Encoder after down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
892+
892893

893894
jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
894895
x, c = self.mid_block(x, cache.get("mid_block"))

0 commit comments

Comments
 (0)