We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 40d3956 commit c777b07Copy full SHA for c777b07
1 file changed
src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -515,6 +515,8 @@ def init_block(rngs):
515
precision=precision,
516
attention=attention,
517
dropout=dropout,
518
+ added_kv_proj_dim=added_kv_proj_dim,
519
+ image_seq_len=pos_embed_seq_len,
520
)
521
522
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -543,6 +545,8 @@ def init_block(rngs):
543
545
544
546
547
enable_jax_named_scopes=enable_jax_named_scopes,
548
549
550
551
blocks.append(block)
552
self.blocks = blocks
0 commit comments