You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* fixes ssim.
* adds pusav1 video dataset.
* wip - adds trainer and attn changes.
* force splash attention for cross attention.
* use nnx.scan over for loop.
* support wan transformers for nnx.scan.
* fix ag from vmap/scan.
* linting.
* remove slg to simplify the code.
Copy file name to clipboardExpand all lines: src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+12-3Lines changed: 12 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -25,6 +25,7 @@
25
25
fromcheximportArray
26
26
from ..utilsimportlogging
27
27
from .. importmax_logging
28
+
from .. importcommon_types
28
29
29
30
30
31
logger=logging.get_logger(__name__)
@@ -86,7 +87,7 @@ def rename_key(key):
86
87
87
88
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
88
89
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
0 commit comments