
Commit fcb1ab1

update unit tests.
1 parent a585a75 commit fcb1ab1

11 files changed

Lines changed: 375 additions & 368 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 12 additions & 11 deletions
@@ -55,17 +55,18 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

-#flash_block_sizes: {}
-flash_block_sizes: {
-  "block_q" : 3024,
-  "block_kv_compute" : 1024,
-  "block_kv" : 2048,
-  "block_q_dkv" : 3024,
-  "block_kv_dkv" : 2048,
-  "block_kv_dkv_compute" : 2048,
-  "block_q_dq" : 3024,
-  "block_kv_dq" : 2048
-}
+flash_block_sizes: {}
+# Use on v6e
+# flash_block_sizes: {
+#   "block_q" : 3024,
+#   "block_kv_compute" : 1024,
+#   "block_kv" : 2048,
+#   "block_q_dkv" : 3024,
+#   "block_kv_dkv" : 2048,
+#   "block_kv_dkv_compute" : 2048,
+#   "block_q_dq" : 3024,
+#   "block_kv_dq" : 2048
+# }
 # GroupNorm groups
 norm_num_groups: 32

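Aside (not part of this commit): the commented-out v6e block sizes line up one-to-one with the fields of the Pallas TPU splash-attention BlockSizes dataclass. Below is a minimal sketch of how such a config dict could be turned into kernel block sizes; the repo's actual max_utils.get_flash_block_sizes may differ, and an empty flash_block_sizes: {} presumably falls back to the kernel defaults.

# Sketch only; assumes JAX's Pallas TPU splash-attention BlockSizes dataclass.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash

# Values copied from the commented-out v6e example above.
config_flash_block_sizes = {
    "block_q": 3024,
    "block_kv_compute": 1024,
    "block_kv": 2048,
    "block_q_dkv": 3024,
    "block_kv_dkv": 2048,
    "block_kv_dkv_compute": 2048,
    "block_q_dq": 3024,
    "block_kv_dq": 2048,
}
block_sizes = splash.BlockSizes(**config_flash_block_sizes)  # one keyword per config key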

src/maxdiffusion/max_utils.py

Lines changed: 6 additions & 2 deletions
@@ -281,9 +281,13 @@ def create_device_mesh(config, devices=None, logging=True):
   ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
   if multi_slice_env:
     dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
-    mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes)
+    mesh = mesh_utils.create_hybrid_device_mesh(
+        ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes
+    )
   else:
-    mesh = mesh_utils.create_device_mesh(ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes)
+    mesh = mesh_utils.create_device_mesh(
+        ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes
+    )

   if logging:
     max_logging.log(f"Decided on mesh: {mesh}")

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
@@ -568,8 +568,8 @@ class AttentionOp(nn.Module):
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
-  axis_names_q: AxisNames = ((BATCH, HEAD, LENGTH, D_KV),)
-  axis_names_kv: AxisNames = ((BATCH, HEAD, KV_LENGTH, D_KV),)
+  axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
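Aside (not part of this commit): the old defaults were 1-tuples containing a 4-tuple, whereas a per-dimension sharding constraint expects one logical axis name per array dimension. A small sketch with placeholder axis-name strings; the real BATCH/HEAD/LENGTH/D_KV constants are defined elsewhere in maxdiffusion and may differ:

import flax.linen as nn
import jax.numpy as jnp

# Placeholder logical-axis names, assumed for illustration only.
BATCH, HEAD, LENGTH, D_KV = "activation_batch", "activation_heads", "activation_length", "activation_kv"

axis_names_q = (BATCH, HEAD, LENGTH, D_KV)  # flat tuple: one name per dimension
query = jnp.zeros((1, 8, 128, 64))          # (batch, heads, seq_len, head_dim)
query = nn.with_logical_constraint(query, axis_names_q)  # no-op unless axis rules and a mesh are active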

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 59 additions & 58 deletions
@@ -137,10 +137,11 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
   else:
     return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)

+
 def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
-  subfolder="transformer"
-  filename="diffusion_pytorch_model.safetensors.index.json"
+  subfolder = "transformer"
+  filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
   if os.path.isdir(pretrained_model_name_or_path):
     index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
@@ -150,72 +151,72 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
   elif hf_download:
     # download the index file for sharded models.
     index_file_path = hf_hub_download(
-        pretrained_model_name_or_path, subfolder=subfolder, filename=filename,
+        pretrained_model_name_or_path,
+        subfolder=subfolder,
+        filename=filename,
     )
-    with jax.default_device(device):
-      # open the index file.
-      with open(index_file_path, "r") as f:
-        index_dict = json.load(f)
-      model_files = set()
-      for key in index_dict["weight_map"].keys():
-        model_files.add(index_dict["weight_map"][key])
-
-      model_files = list(model_files)
-      tensors = {}
-      for model_file in model_files:
-        if local_files:
-          ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
-        else:
-          ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-        # now get all the filenames for the model that need downloading
-        max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
-
-        if ckpt_shard_path is not None:
-          with safe_open(ckpt_shard_path, framework="pt") as f:
-            for k in f.keys():
-              tensors[k] = torch2jax(f.get_tensor(k))
-      flax_state_dict = {}
-      cpu = jax.local_devices(backend="cpu")[0]
-      flattened_dict = flatten_dict(eval_shapes)
-      # turn all block numbers to strings just for matching weights.
-      # Later they will be turned back to ints.
-      random_flax_state_dict = {}
-      for key in flattened_dict:
-        string_tuple = tuple([str(item) for item in key])
-        random_flax_state_dict[string_tuple] = flattened_dict[key]
-      del flattened_dict
-      for pt_key, tensor in tensors.items():
-        renamed_pt_key = rename_key(pt_key)
-        renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
-        renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
-        renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
-        renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
-        renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
-        pt_tuple_key = tuple(renamed_pt_key.split("."))
-
-        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
-        flax_key = rename_for_nnx(flax_key)
-        flax_key = _tuple_str_to_int(flax_key)
-        flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
-      validate_flax_state_dict(eval_shapes, flax_state_dict)
-      flax_state_dict = unflatten_dict(flax_state_dict)
-      del tensors
-      jax.clear_caches()
-      return flax_state_dict
+  with jax.default_device(device):
+    # open the index file.
+    with open(index_file_path, "r") as f:
+      index_dict = json.load(f)
+    model_files = set()
+    for key in index_dict["weight_map"].keys():
+      model_files.add(index_dict["weight_map"][key])
+
+    model_files = list(model_files)
+    tensors = {}
+    for model_file in model_files:
+      if local_files:
+        ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
+      else:
+        ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
+      # now get all the filenames for the model that need downloading
+      max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
+
+      if ckpt_shard_path is not None:
+        with safe_open(ckpt_shard_path, framework="pt") as f:
+          for k in f.keys():
+            tensors[k] = torch2jax(f.get_tensor(k))
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+    # turn all block numbers to strings just for matching weights.
+    # Later they will be turned back to ints.
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      string_tuple = tuple([str(item) for item in key])
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    del flattened_dict
+    for pt_key, tensor in tensors.items():
+      renamed_pt_key = rename_key(pt_key)
+      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+      renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
+      renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
+      renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
+      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
+      flax_key = rename_for_nnx(flax_key)
+      flax_key = _tuple_str_to_int(flax_key)
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    flax_state_dict = unflatten_dict(flax_state_dict)
+    del tensors
+    jax.clear_caches()
+    return flax_state_dict


 def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
-  subfolder="vae"
-  filename="diffusion_pytorch_model.safetensors"
+  subfolder = "vae"
+  filename = "diffusion_pytorch_model.safetensors"
   if os.path.isdir(pretrained_model_name_or_path):
     ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
     if not os.path.isfile(ckpt_path):
       raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
   elif hf_download:
-    ckpt_path = hf_hub_download(
-        pretrained_model_name_or_path, subfolder=subfolder, filename=filename
-    )
+    ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
   max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
   with jax.default_device(device):
     if ckpt_path is not None:
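Aside (not part of this commit): the visible change in the large hunk above appears to be indentation only. The with jax.default_device(device): block moves out of the elif hf_download: branch, so the weight-porting loop also runs when pretrained_model_name_or_path points at a local directory. For reference, the shard-loading pattern reduced to its core; torch2jax below is a hypothetical stand-in for the repo's actual converter:

import jax.numpy as jnp
from safetensors import safe_open  # real safetensors API


def torch2jax(t):
  # Hypothetical stand-in: upcast to float32 so bfloat16 weights survive the numpy round-trip.
  return jnp.asarray(t.float().numpy())


def load_shards(shard_paths):
  tensors = {}
  for path in shard_paths:
    with safe_open(path, framework="pt") as f:  # opens one safetensors shard lazily
      for k in f.keys():
        tensors[k] = torch2jax(f.get_tensor(k))
  return tensors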

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 19 additions & 7 deletions
@@ -183,7 +183,7 @@ def load_tokenizer(cls, config: HyperParameters):

   @classmethod
   def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-
+
     def create_model(rngs: nnx.Rngs, config: HyperParameters):
       wan_vae = AutoencoderKLWan.from_config(
           config.pretrained_model_name_or_path,
@@ -194,11 +194,12 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
           weights_dtype=config.weights_dtype,
       )
       return wan_vae
-    # 1. eval shape
+
+    # 1. eval shape
     p_model_factory = partial(create_model, config=config)
     wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs)
     graphdef, state = nnx.split(wan_vae, nnx.Param)
-
+
     # 2. retrieve the state shardings, mapping logical names to mesh axis names.
     logical_state_spec = nnx.get_partition_spec(state)
     logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -215,7 +216,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
       sharding = logical_state_sharding[path].value
       state[path].value = device_put_replicated(val, sharding)
     state = nnx.from_flat_state(state)
-
+
     wan_vae = nnx.merge(graphdef, state)
     vae_cache = AutoencoderKLWanCache(wan_vae)
     return wan_vae, vae_cache
@@ -463,7 +464,18 @@ def __call__(


 @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
-def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask, do_classifier_free_guidance, guidance_scale):
+def transformer_forward_pass(
+    graphdef,
+    sharded_state,
+    rest_of_state,
+    latents,
+    timestep,
+    prompt_embeds,
+    is_uncond,
+    slg_mask,
+    do_classifier_free_guidance,
+    guidance_scale,
+):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
   noise_pred = wan_transformer(
       hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask
@@ -474,7 +486,7 @@ def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, ti
     noise_pred = noise_pred[:bsz]
     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
     latents = latents[:bsz]
-
+
   return noise_pred, latents


@@ -516,7 +528,7 @@ def run_inference(
           is_uncond=jnp.array(True, dtype=jnp.bool_),
           slg_mask=slg_mask,
          do_classifier_free_guidance=do_classifier_free_guidance,
-          guidance_scale=guidance_scale
+          guidance_scale=guidance_scale,
      )

      latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
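Aside (not part of this commit): static_argnames tells jax.jit to treat the classifier-free-guidance flag and guidance scale as compile-time Python values rather than traced arrays, so Python-level branching on them is allowed; passing a different value triggers a retrace. A minimal sketch, unrelated to the pipeline's actual signature:

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
def forward(latents, do_classifier_free_guidance, guidance_scale):
  if do_classifier_free_guidance:  # fine: a static Python bool, not a tracer
    return latents * guidance_scale
  return latents


out = forward(jnp.ones((2, 4)), do_classifier_free_guidance=True, guidance_scale=5.0)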

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 0 deletions
@@ -36,9 +36,11 @@ def string_to_bool(s: str) -> bool:
     return False
   raise ValueError(f"Can't convert {s} to bool")

+
 def string_to_list(string_list: str) -> list:
   return ast.literal_eval(string_list)

+
 _yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool, list: string_to_list}

 _config = None
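Aside (not part of this commit): string_to_list relies on ast.literal_eval, which safely parses Python literals supplied as command-line overrides; dict-valued overrides such as the flash_block_sizes string in the removed test further below can be parsed the same way, though the exact routing inside pyconfig is not shown here. For example:

import ast

mesh_axes = ast.literal_eval("['data', 'fsdp', 'tensor']")                  # -> list of strings
block_sizes = ast.literal_eval('{"block_q": 256, "block_kv": 256}')         # dict literals parse identically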

src/maxdiffusion/tests/attention_test.py

Lines changed: 15 additions & 44 deletions
@@ -23,7 +23,6 @@
 from ..models.attention_flax import FlaxAttention
 from .. import max_utils
 from .. import pyconfig
-from maxdiffusion import FlaxUNet2DConditionModel

 THIS_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -73,54 +72,26 @@ def test_splash_attention(self):
     devices_array = max_utils.create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
     flash_block_sizes = max_utils.get_flash_block_sizes(config)
-    splash_attention = FlaxAttention(
-        heads * head_depth,
-        heads,
-        head_depth,
-        split_head_dim=True,
-        attention_kernel="flash",
-        mesh=mesh,
-        dtype=jnp.bfloat16,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    params = splash_attention.init(key2, x)["params"]
-    p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile()
-    splash_attention_out = p_apply({"params": params}, x)
+    with mesh:
+      splash_attention = FlaxAttention(
+          heads * head_depth,
+          heads,
+          head_depth,
+          split_head_dim=True,
+          attention_kernel="flash",
+          mesh=mesh,
+          dtype=jnp.bfloat16,
+          flash_block_sizes=flash_block_sizes,
+      )
+
+      params = splash_attention.init(key2, x)["params"]
+      p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile()
+      splash_attention_out = p_apply({"params": params}, x)

     diff_norm = jnp.linalg.norm(dot_attention_out - splash_attention_out)

     assert diff_norm < 1.0

-  def test_flash_block_sizes(self):
-    """Test loading flash block sizes from cli."""
-
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
-            'flash_block_sizes={"block_q" : 256, "block_kv_compute": 256, "block_kv": 256,'
-            '"block_q_dkv": 256, "block_kv_dkv": 256, "block_kv_dkv_compute": 256,'
-            '"block_q_dq": 256, "block_kv_dq": 256}',
-            "attention=flash",
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-    devices_array = max_utils.create_device_mesh(config)
-    mesh = Mesh(devices_array, config.mesh_axes)
-    flash_block_sizes = max_utils.get_flash_block_sizes(config)
-    _, _ = FlaxUNet2DConditionModel.from_pretrained(
-        config.pretrained_model_name_or_path,
-        revision=config.revision,
-        subfolder="unet",
-        dtype=jnp.bfloat16,
-        from_pt=config.from_pt,
-        attention_kernel=config.attention,
-        flash_block_sizes=flash_block_sizes,
-        mesh=mesh,
-    )
-

 if __name__ == "__main__":
   absltest.main()
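Aside (not part of this commit): the test now constructs, initializes, and applies FlaxAttention inside with mesh:, which makes the Mesh the ambient context so that sharding constraints used along the flash-attention path can resolve named mesh axes. The general pattern, with hypothetical axis names:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

device_array = mesh_utils.create_device_mesh((1, len(jax.devices())))
mesh = Mesh(device_array, ("data", "fsdp"))  # hypothetical axis names
with mesh:
  # module construction, init, and apply that rely on named mesh axes go here
  pass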
