Skip to content

Commit 601f40c

Browse files
add missing conversions of pt to jax weights.
1 parent 93a3bb6 commit 601f40c

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
5555
("to_k", "key"),
5656
("to_v", "value"),
5757
("to_q", "query"),
58+
("txt_attn_proj", "txt_attn_proj"),
59+
("img_attn_proj", "img_attn_proj"),
60+
("txt_attn_qkv", "txt_attn_qkv"),
61+
("img_attn_qkv", "img_attn_qkv"),
5862
):
5963
if pt_tuple_key[-2] == rename_from:
6064
weight_name = pt_tuple_key[-1]

0 commit comments

Comments
 (0)