Skip to content

Commit b5d6dd7

Browse files
committed
more changes
1 parent fad8be1 commit b5d6dd7

6 files changed

Lines changed: 59 additions & 16 deletions

File tree

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ hf-transfer>=0.1.9
6767
hf-xet>=1.4.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
6868
httpcore>=1.0.9
6969
httpx>=0.28.1
70-
huggingface-hub>=0.36.2
70+
huggingface-hub>=1.10.1
7171
humanize>=4.15.0
7272
hypothesis>=6.142.1
7373
idna>=3.11

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ ftfy
1414
tensorboard>=2.17.0
1515
tensorboardx>=2.6.2.2
1616
tensorboard-plugin-profile>=2.15.2
17-
tokamax
1817
Jinja2
1918
scikit-image
2019
parameterized

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
2828
from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
2929
from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
30+
from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base
3031
from einops import rearrange
3132
from .. import common_types, max_logging
3233

@@ -363,7 +364,10 @@ def wrap_flash_attention(query, key, value):
363364
# Both are (kv_padded_len,) - element-wise multiplication
364365
kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
365366

366-
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
367+
if attention_kernel == "tokamax_ring":
368+
segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
369+
else:
370+
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
367371

368372
# make_splash_mha is wrapped around shardmap and seq and head is already
369373
# sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
@@ -1954,4 +1958,4 @@ def setup(self):
19541958
def __call__(self, hidden_states, deterministic=True):
19551959
hidden_states = self.proj(hidden_states)
19561960
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
1957-
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
1961+
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,6 @@ def __init__(
759759
precision=precision,
760760
)
761761

762-
@nnx.jit(static_argnames="feat_idx")
763762
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
764763
if feat_cache is not None:
765764
idx = feat_idx
@@ -908,7 +907,6 @@ def __init__(
908907
precision=precision,
909908
)
910909

911-
@nnx.jit(static_argnames="feat_idx")
912910
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
913911
if feat_cache is not None:
914912
idx = feat_idx
@@ -1104,6 +1102,7 @@ def __init__(
11041102
)
11051103
self.mesh = mesh
11061104

1105+
@nnx.jit
11071106
def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11081107
feat_cache.init_cache()
11091108
if x.shape[-1] != 3:
@@ -1175,6 +1174,7 @@ def encode(
11751174
return (posterior,)
11761175
return FlaxAutoencoderKLOutput(latent_dist=posterior)
11771176

1177+
@nnx.jit
11781178
def _decode(
11791179
self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
11801180
) -> Union[FlaxDecoderOutput, jax.Array]:

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,33 @@ def load_base_wan_transformer(
212212
device = jax.local_devices(backend=device)[0]
213213
filename = "diffusion_pytorch_model.safetensors.index.json"
214214
local_files = False
215+
216+
# Only rank 0 downloads; others wait for cache to be populated
217+
process_index = jax.process_index()
215218
if os.path.isdir(pretrained_model_name_or_path):
216219
index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
217220
if not os.path.isfile(index_file_path):
218221
raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
219222
local_files = True
220223
elif hf_download:
221-
# download the index file for sharded models.
222-
index_file_path = hf_hub_download(
223-
pretrained_model_name_or_path,
224-
subfolder=subfolder,
225-
filename=filename,
226-
)
224+
# Only rank 0 downloads; synchronize across all ranks
225+
if process_index == 0:
226+
# download the index file for sharded models.
227+
index_file_path = hf_hub_download(
228+
pretrained_model_name_or_path,
229+
subfolder=subfolder,
230+
filename=filename,
231+
)
232+
jax.experimental.multihost_utils.sync_global_devices("model_index_download")
233+
234+
if process_index != 0:
235+
# Non-rank-0 processes wait and use the cached path
236+
index_file_path = hf_hub_download(
237+
pretrained_model_name_or_path,
238+
subfolder=subfolder,
239+
filename=filename,
240+
force_download=False, # Use cache, don't download
241+
)
227242
with jax.default_device(device):
228243
# open the index file.
229244
with open(index_file_path, "r") as f:
@@ -238,7 +253,19 @@ def load_base_wan_transformer(
238253
if local_files:
239254
ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
240255
else:
241-
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
256+
# Only rank 0 downloads new files; others use cached versions
257+
if process_index == 0:
258+
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
259+
jax.experimental.multihost_utils.sync_global_devices(f"model_download_{model_file}")
260+
261+
if process_index != 0:
262+
# Non-rank-0: use cached version
263+
ckpt_shard_path = hf_hub_download(
264+
pretrained_model_name_or_path,
265+
subfolder=subfolder,
266+
filename=model_file,
267+
force_download=False, # Use cache
268+
)
242269
# now get all the filenames for the model that need downloading
243270
max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
244271

@@ -304,12 +331,25 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
304331
device = jax.devices(device)[0]
305332
subfolder = "vae"
306333
filename = "diffusion_pytorch_model.safetensors"
334+
process_index = jax.process_index()
335+
307336
if os.path.isdir(pretrained_model_name_or_path):
308337
ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
309338
if not os.path.isfile(ckpt_path):
310339
raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
311340
elif hf_download:
312-
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
341+
# Only rank 0 downloads; others use cache
342+
if process_index == 0:
343+
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
344+
jax.experimental.multihost_utils.sync_global_devices("vae_download")
345+
346+
if process_index != 0:
347+
ckpt_path = hf_hub_download(
348+
pretrained_model_name_or_path,
349+
subfolder=subfolder,
350+
filename=filename,
351+
force_download=False, # Use cache
352+
)
313353
max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}")
314354
with jax.default_device(device):
315355
if ckpt_path is not None:

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def encode_prompt(
492492
num_videos_per_prompt=num_videos_per_prompt,
493493
max_sequence_length=max_sequence_length,
494494
)
495-
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
495+
prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
496496

497497
if negative_prompt_embeds is None:
498498
negative_prompt = negative_prompt or ""
@@ -502,7 +502,7 @@ def encode_prompt(
502502
num_videos_per_prompt=num_videos_per_prompt,
503503
max_sequence_length=max_sequence_length,
504504
)
505-
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
505+
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
506506

507507
return prompt_embeds, negative_prompt_embeds
508508

0 commit comments

Comments (0)