
Commit 56f5225

Support batch size > 1. Issue where all generations except the first come out bad.
1 parent d64e521 commit 56f5225

5 files changed: 68 additions & 61 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 8 additions & 8 deletions
@@ -53,14 +53,14 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

 flash_block_sizes: {
-  "block_q" : 512,
-  "block_kv_compute" : 512,
-  "block_kv" : 512,
-  "block_q_dkv" : 512,
-  "block_kv_dkv" : 512,
-  "block_kv_dkv_compute" : 512,
-  "block_q_dq" : 512,
-  "block_kv_dq" : 512
+  "block_q" : 1024,
+  "block_kv_compute" : 1024,
+  "block_kv" : 1024,
+  "block_q_dkv" : 1024,
+  "block_kv_dkv" : 1024,
+  "block_kv_dkv_compute" : 1024,
+  "block_q_dq" : 1024,
+  "block_kv_dq" : 1024
 }
 # GroupNorm groups
 norm_num_groups: 32
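These values are consumed when attention is 'flash'; the pipeline change below imports get_flash_block_sizes from max_utils to do the conversion. A minimal sketch of how such a mapping can be packed into TPU splash-attention block sizes, assuming the JAX Pallas BlockSizes dataclass (the actual max_utils helper is not shown in this commit):

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Mirror of the YAML mapping above; every tile size bumped from 512 to 1024.
flash_block_sizes = {
    "block_q": 1024,
    "block_kv_compute": 1024,
    "block_kv": 1024,
    "block_q_dkv": 1024,
    "block_kv_dkv": 1024,
    "block_kv_dkv_compute": 1024,
    "block_q_dq": 1024,
    "block_kv_dq": 1024,
}

# BlockSizes groups the forward/backward kernel tile sizes by name.
block_sizes = splash_attention_kernel.BlockSizes(**flash_block_sizes)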

src/maxdiffusion/generate_wan.py

Lines changed: 6 additions & 4 deletions
@@ -23,7 +23,7 @@
 def run(config):
   pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
-  video = pipeline(
+  videos = pipeline(
     prompt=config.prompt,
     negative_prompt=config.negative_prompt,
     height=config.height,
@@ -34,10 +34,11 @@ def run(config):
   )

   print("compile time: ", (time.perf_counter() - s0))
-  export_to_video(video[0], "jax_output.mp4", fps=16)
+  for i in range(len(videos)):
+    export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16)
   s0 = time.perf_counter()
   with jax.profiler.trace("/tmp/trace/"):
-    video = pipeline(
+    videos = pipeline(
       prompt=config.prompt,
       negative_prompt=config.negative_prompt,
       height=config.height,
@@ -47,7 +48,8 @@ def run(config):
       guidance_scale=config.guidance_scale,
     )
   print("generation time: ", (time.perf_counter() - s0))
-  export_to_video(video[0], "jax_output.mp4", fps=16)
+  for i in range(len(videos)):
+    export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16)


 def main(argv: Sequence[str]) -> None:
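With batch support, a list of prompts yields one video per prompt. A hedged usage fragment that would slot into run() above; pipeline, config, and export_to_video come from the surrounding script, and the example prompts are made up:

# Assumes `pipeline = WanPipeline.from_pretrained(config)` as in run() above.
videos = pipeline(
    prompt=["a red panda surfing", "a steam train crossing a bridge at dusk"],
    negative_prompt=config.negative_prompt,
    height=config.height,
    guidance_scale=config.guidance_scale,
    # ...remaining generation arguments exactly as in run() above...
)

# One output file per batch element, matching the loop added in this commit.
for i, video in enumerate(videos):
    export_to_video(video, f"wan_output_{i}.mp4", fps=16)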

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 10 additions & 5 deletions
@@ -1131,12 +1131,17 @@ def _decode(
     # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
     # Most likely due to an incorrect reshaping in the decoder.
     fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
-    if len(fm1.shape) == 4:
-      fm1 = jnp.expand_dims(fm1, axis=0)
-      fm2 = jnp.expand_dims(fm2, axis=0)
-      fm3 = jnp.expand_dims(fm3, axis=0)
-      fm4 = jnp.expand_dims(fm4, axis=0)
+    # When batch_size is 1, expand the batch dim for concatenation;
+    # else, expand the frame dim for concatenation so that the batch dim stays intact.
+    axis = 0
+    if fm1.shape[0] > 1:
+      axis = 1

+    if len(fm1.shape) == 4:
+      fm1 = jnp.expand_dims(fm1, axis=axis)
+      fm2 = jnp.expand_dims(fm2, axis=axis)
+      fm3 = jnp.expand_dims(fm3, axis=axis)
+      fm4 = jnp.expand_dims(fm4, axis=axis)
     out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
     out = jnp.clip(out, min=-1.0, max=1.0)
     feat_cache.clear_cache()
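A toy shape check of the axis choice above (made-up shapes, not the real decoder tensors): with a single sample the frame slice is (1, C, H, W) and gets a new leading axis, while with a batch the slice is (B, C, H, W) and the new axis goes after the batch so the batch dim stays in front.

import jax.numpy as jnp

fm_single = jnp.zeros((1, 3, 8, 8))   # batch_size == 1
fm_batched = jnp.zeros((2, 3, 8, 8))  # batch_size > 1

print(jnp.expand_dims(fm_single, axis=0).shape)   # (1, 1, 3, 8, 8)
print(jnp.expand_dims(fm_batched, axis=1).shape)  # (2, 1, 3, 8, 8)

Both results are 5-D and can be concatenated with the decoder output along axis=1, the frame axis.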

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
-      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None, "conv_out",)),
     )

     # 2. Condition embeddings
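The new annotation replaces the earlier ("batch",) spec and shards only the last kernel axis (output features) under the logical name "conv_out", leaving the other four kernel axes replicated. A standalone sketch with a made-up 3D convolution, not the WanModel layer itself:

from flax import nnx

rngs = nnx.Rngs(0)
conv = nnx.Conv(
    in_features=16,
    out_features=32,
    kernel_size=(1, 2, 2),
    # Five kernel axes (t, h, w, in, out); only the output-feature axis
    # carries a logical sharding name, the rest stay unsharded.
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(), (None, None, None, None, "conv_out")
    ),
    rngs=rngs,
)
print(conv.kernel.value.shape)  # (1, 2, 2, 16, 32)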

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 43 additions & 43 deletions
@@ -17,13 +17,14 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax.sharding import Mesh, PositionalSharding
+from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
 import flax
 import flax.linen as nn
 from flax import nnx
 from ...pyconfig import HyperParameters
 from ... import max_logging
 from ... import max_utils
+from ...max_utils import get_flash_block_sizes, get_precision
 from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
 from ...models.wan.transformers.transformer_wan import WanModel
 from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
@@ -59,11 +60,12 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl

 partial(nnx.jit, static_argnums=(3,))
 def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-  # breakpoint()
+
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
     return wan_transformer

+  # 1. Load config.
   wan_config = WanModel.load_config(
     config.pretrained_model_name_or_path,
     subfolder="transformer"
@@ -72,32 +74,39 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["dtype"] = config.activations_dtype
   wan_config["weights_dtype"] = config.weights_dtype
   wan_config["attention"] = config.attention
+  wan_config["precision"] = get_precision(config)
+  wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
+
+  # 2. eval_shape - will not use flops or create weights on device,
+  # thus not using HBM memory.
   p_model_factory = partial(create_model, wan_config=wan_config)
   wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
   graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-  #breakpoint()
+
+  # 3. Retrieve the state shardings, mapping logical names to mesh axis names.
   logical_state_spec = nnx.get_partition_spec(state)
   logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
   logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
   params = state.to_pure_dict()
   state = dict(nnx.to_flat_state(state))
-  # del state
+
+  # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
+  # This helps with loading sharded weights directly into the accelerators without first copying them
+  # all to one device and then distributing them, thus using low HBM memory.
   params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
-    state[path].value = jax.device_put(val, sharding)
+    try:
+      state[path].value = jax.device_put(val, sharding)
+    except:
+      breakpoint()
   state = nnx.from_flat_state(state)
-  p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules)
-  state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
-  pspecs = nnx.get_partition_spec(state)
-  #breakpoint()
-  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-  #breakpoint()
-  #wan_transformer = jax.jit(nnx.merge(graphdef, sharded_state, rest_of_state), in_shardings=None, out_shardings=sharded_state)
-  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+
+  wan_transformer = nnx.merge(graphdef, state, rest_of_state)
   return wan_transformer

+
 partial(nnx.jit, static_argnums=(1,))
 def create_sharded_logical_model(model, logical_axis_rules):
   graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...)
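A tiny illustration of step 2 in the hunk above, using a stand-in nnx.Linear instead of WanModel: nnx.eval_shape traces the factory abstractly, so parameters come back as shape/dtype structs and no device memory is used until the real weights are device_put with the shardings derived in steps 3 and 4.

from flax import nnx

abstract_model = nnx.eval_shape(lambda: nnx.Linear(4096, 4096, rngs=nnx.Rngs(0)))
# The kernel is an abstract shape/dtype struct, not an allocated array.
print(abstract_model.kernel.value)  # ShapeDtypeStruct(shape=(4096, 4096), dtype=float32)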
@@ -108,6 +117,7 @@ def create_sharded_logical_model(model, logical_axis_rules):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
   return wan_transformer

+
 class WanPipeline:
   r"""
   Pipeline for text-to-video generation using Wan.
@@ -155,6 +165,7 @@ def __init__(

     self.p_run_inference = None

+
   @classmethod
   def load_text_encoder(cls, config: HyperParameters):
     text_encoder = UMT5EncoderModel.from_pretrained(
@@ -163,6 +174,7 @@ def load_text_encoder(cls, config: HyperParameters):
     )
     return text_encoder

+
   @classmethod
   def load_tokenizer(cls, config: HyperParameters):
     tokenizer = AutoTokenizer.from_pretrained(
@@ -171,6 +183,7 @@ def load_tokenizer(cls, config: HyperParameters):
     )
     return tokenizer

+
   @classmethod
   def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     wan_vae = AutoencoderKLWan.from_config(
@@ -196,33 +209,14 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
       wan_vae = p_create_sharded_logical_model(model=wan_vae)
     return wan_vae, vae_cache

+
   @classmethod
   def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     with mesh:
       wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
-      # wan_transformer = WanModel.from_config(
-      #   config.pretrained_model_name_or_path,
-      #   subfolder="transformer",
-      #   rngs=rngs,
-      #   attention=config.attention,
-      #   mesh=mesh,
-      #   dtype=config.activations_dtype,
-      #   weights_dtype=config.weights_dtype
-      # )
-      # graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-      # breakpoint()
-      # params = state.to_pure_dict()
-      # del state
-      # #params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
-      # params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-      # #params = jax.device_put(params, PositionalSharding(devices_array).replicate())
-      # wan_transformer = nnx.merge(graphdef, params, rest_of_state)
-      # # Shard
-      # p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
-      # with mesh:
-      #   wan_transformer = p_create_sharded_logical_model(model=wan_transformer)
     return wan_transformer

+
   @classmethod
   def load_scheduler(cls, config):
     scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained(
@@ -232,6 +226,7 @@ def load_scheduler(cls, config):
     )
     return scheduler, scheduler_state

+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False):
     devices_array = max_utils.create_device_mesh(config)
@@ -268,6 +263,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
       config=config
     )

+
   def _get_t5_prompt_embeds(
     self,
     prompt: Union[str, List[str]] = None,
@@ -302,6 +298,7 @@ def _get_t5_prompt_embeds(

     return prompt_embeds

+
   def encode_prompt(
     self,
     prompt: Union[str, List[str]],
@@ -333,6 +330,7 @@ def encode_prompt(

     return prompt_embeds, negative_prompt_embeds

+
   def prepare_latents(
     self,
     batch_size: int,
@@ -356,6 +354,7 @@ def prepare_latents(

     return latents

+
   def __call__(
     self,
     prompt: Union[str, List[str]] = None,
@@ -382,9 +381,9 @@

     # 2. Define call parameters
     if prompt is not None and isinstance(prompt, str):
-      batch_size = 1
-    elif prompt is not None and isinstance(prompt, list):
-      batch_size = len(prompt)
+      prompt = [prompt]
+
+    batch_size = len(prompt)

     prompt_embeds, negative_prompt_embeds = self.encode_prompt(
       prompt=prompt,
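In isolation, the new control flow always ends with a list, so batch_size is simply its length:

prompt = "a cat playing piano"
if prompt is not None and isinstance(prompt, str):
    prompt = [prompt]
batch_size = len(prompt)  # 1 for a single string, N for a list of N prompts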
@@ -406,12 +405,13 @@
       num_channels_latents=num_channel_latents
     )

-    prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-    negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-
-    latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
-    prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
-    negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
+    data_sharding = PositionalSharding(self.devices_array).replicate()
+    if len(prompt) % jax.device_count() == 0:
+      data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
+
+    latents = jax.device_put(latents, data_sharding)
+    prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
+    negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)

     scheduler_state = self.scheduler.set_timesteps(
       self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
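The replicate-vs-shard decision above, sketched standalone: shard the batch dimension across devices only when it divides evenly, otherwise replicate. The single-axis "data" mesh and latent shape here are assumptions standing in for the pipeline's mesh and config.data_sharding:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PositionalSharding, PartitionSpec as P

devices_array = np.array(jax.devices())
mesh = Mesh(devices_array, ("data",))

def choose_data_sharding(batch_size: int):
    # Fall back to full replication when the batch cannot be split evenly.
    sharding = PositionalSharding(devices_array).replicate()
    if batch_size % jax.device_count() == 0:
        # Each device holds an equal slice of the batch ("data") dimension.
        sharding = NamedSharding(mesh, P("data"))
    return sharding

latents = jnp.zeros((4, 16, 21, 60, 90))  # made-up latent shape
latents = jax.device_put(latents, choose_data_sharding(latents.shape[0]))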
