Skip to content

Commit c777b07

Browse files
committed
passed img related params to WanTransformerBlock call
1 parent 40d3956 commit c777b07

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ def init_block(rngs):
515515
precision=precision,
516516
attention=attention,
517517
dropout=dropout,
518+
added_kv_proj_dim=added_kv_proj_dim,
519+
image_seq_len=pos_embed_seq_len,
518520
)
519521

520522
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -543,6 +545,8 @@ def init_block(rngs):
543545
precision=precision,
544546
attention=attention,
545547
enable_jax_named_scopes=enable_jax_named_scopes,
548+
added_kv_proj_dim=added_kv_proj_dim,
549+
image_seq_len=pos_embed_seq_len,
546550
)
547551
blocks.append(block)
548552
self.blocks = blocks

0 commit comments

Comments
 (0)