
Commit 9b07874

use local_devices instead of devices, which defaults to the first machine's devices. (#212)
1 parent e556ca1 commit 9b07874
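The motivation for the change: on a multi-host JAX deployment, jax.devices() enumerates devices across all processes, so jax.devices(device)[0] can return a device owned by process 0 rather than by the current host, while jax.local_devices(backend=device)[0] always returns a device the calling process can address. A minimal standalone sketch of the difference (illustrative, not code from this repo):

import jax

# jax.devices() lists devices across *all* processes in a multi-host
# setup; [0] is process 0's first device, which this host may not own.
first_global_cpu = jax.devices("cpu")[0]

# jax.local_devices() lists only the current process's devices, so [0]
# is always addressable from this host.
first_local_cpu = jax.local_devices(backend="cpu")[0]

# The loaders changed below rely on this device being local:
with jax.default_device(first_local_cpu):
    x = jax.numpy.zeros((4,))  # placed on this host's CPU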

2 files changed: 24 additions & 14 deletions

src/maxdiffusion/models/wan/wan_utils.py (21 additions & 13 deletions)
@@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key):
   return renamed_pt_key
 
 
-def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_fusionx_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
     if hf_download:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
@@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
       if flax_key in flax_state_dict:
         new_tensor = flax_state_dict[flax_key]
       else:
-        new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+        new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
       flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   return flax_state_dict
 
 
-def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_causvid_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
     if hf_download:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
@@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
       if flax_key in flax_state_dict:
         new_tensor = flax_state_dict[flax_key]
       else:
-        new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+        new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
       flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   return flax_state_dict
 
 
-def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+def load_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
 
   if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
   elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
   else:
-    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
 
 
-def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_base_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   subfolder = "transformer"
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
@@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
       if flax_key in flax_state_dict:
         new_tensor = flax_state_dict[flax_key]
       else:
-        new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+        new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
       flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
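Parameterizing the hardcoded 40 matters because these loaders stack per-block checkpoint tensors along a new leading layer axis, and the stack's depth must match the model's block count. A minimal sketch of the stacking pattern, with hypothetical names and shapes (block_weight and the 30-layer depth are illustrative, not the repo's real tensors):

import jax.numpy as jnp

num_layers = 30                      # model depth, now a parameter instead of a hardcoded 40
block_index = 3                      # which transformer block this tensor came from
block_weight = jnp.ones((128, 128))  # hypothetical single-block weight

# Allocate a buffer with a leading layer axis on the first block seen,
# then write each block's tensor into its own slice.
stacked = jnp.zeros((num_layers,) + block_weight.shape)
stacked = stacked.at[block_index].set(block_weight)
assert stacked.shape == (num_layers, 128, 128)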

src/maxdiffusion/pipelines/wan/wan_pipeline.py (3 additions & 1 deletion)
@@ -95,7 +95,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
+  params = load_wan_transformer(
+      config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+  )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
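With this change the pipeline threads the configured depth into the loader instead of relying on the 40-block default. A hypothetical call site (the checkpoint id, empty eval_shapes, and 30-layer depth are illustrative assumptions; running it end to end requires the maxdiffusion package and a real checkpoint):

from maxdiffusion.models.wan.wan_utils import load_wan_transformer

eval_shapes = {}  # placeholder; normally the transformer's expected parameter shapes
params = load_wan_transformer(
    "some-org/wan-t2v-transformer",  # hypothetical checkpoint id
    eval_shapes,
    "cpu",           # backend string, resolved internally via jax.local_devices
    num_layers=30,   # a non-default depth, e.g. a smaller Wan variant
)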
