Commit d64e521

Reduces memory significantly when loading the transformer: the model is created abstractly with nnx.eval_shape and each loaded weight is placed directly onto its target shards, instead of first replicating the full parameter tree on every device. Needs cleanup.

1 parent 5f2434d · commit d64e521

2 files changed: 108 additions & 50 deletions

File tree (2 files):
src/maxdiffusion/models/wan/transformers/transformer_wan.py
src/maxdiffusion/pipelines/wan/wan_pipeline.py

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 42 additions & 31 deletions
@@ -34,6 +34,38 @@
 
 BlockSizes = common_types.BlockSizes
 
+def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
+  h_dim = w_dim = 2 * (attention_head_dim // 6)
+  t_dim = attention_head_dim - h_dim - w_dim
+  freqs = []
+  for dim in [t_dim, h_dim, w_dim]:
+    freq = get_1d_rotary_pos_embed(
+        dim,
+        max_seq_len,
+        theta,
+        freqs_dtype=jnp.float64,
+        use_real=False
+    )
+    freqs.append(freq)
+  freqs = jnp.concatenate(freqs, axis=1)
+  # sizes = jnp.array([
+  #     attention_head_dim // 2 - 2 * (attention_head_dim // 6),
+  #     attention_head_dim // 6,
+  #     attention_head_dim // 6,
+  # ])
+  # cumulative_sizes = jnp.cumsum(jnp.array(sizes))
+  # split_indices = cumulative_sizes[:-1]
+  t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
+  hw_size = attention_head_dim // 6
+
+  dims = [t_size, hw_size, hw_size]
+
+  # Calculate split indices as a static list of integers so jnp.split stays static under jit.
+  cumulative_sizes = np.cumsum(dims)
+  split_indices = cumulative_sizes[:-1].tolist()
+  freqs_split = jnp.split(freqs, split_indices, axis=1)
+  return freqs_split
+
 class WanRotaryPosEmbed(nnx.Module):
   def __init__(
       self,
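
Note on the split above: the commented-out variant computed the split points with jnp.cumsum, which yields a traced device array and cannot serve as jnp.split indices when the function is traced under jit; the working version takes the cumulative sum on the host with NumPy so the indices are plain Python ints. A minimal standalone sketch of the pattern, with sizes assumed for attention_head_dim = 128 (use_real=False makes the table complex, so each axis contributes dim // 2 columns):

import numpy as np
import jax.numpy as jnp

dims = [22, 21, 21]                       # assumed t/h/w column counts for head_dim = 128
freqs = jnp.ones((1024, sum(dims)), dtype=jnp.complex64)  # stand-in for the concatenated table
split_indices = np.cumsum(dims)[:-1].tolist()             # [22, 43] -- static Python ints
t_f, h_f, w_f = jnp.split(freqs, split_indices, axis=1)
print(t_f.shape, h_f.shape, w_f.shape)    # (1024, 22) (1024, 21) (1024, 21)
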
@@ -45,44 +77,23 @@ def __init__(
     self.attention_head_dim = attention_head_dim
     self.patch_size = patch_size
     self.max_seq_len = max_seq_len
-
-    h_dim = w_dim = 2 * (attention_head_dim // 6)
-    t_dim = attention_head_dim - h_dim - w_dim
-
-    freqs = []
-    for dim in [t_dim, h_dim, w_dim]:
-      freq = get_1d_rotary_pos_embed(
-          dim,
-          self.max_seq_len,
-          theta,
-          freqs_dtype=jnp.float64,
-          use_real=False
-      )
-      freqs.append(freq)
-    freqs = jnp.concatenate(freqs, axis=1)
-
-    sizes = [
-        self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-        self.attention_head_dim // 6,
-        self.attention_head_dim // 6,
-    ]
-    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
-    split_indices = cumulative_sizes[:-1]
-    self.freqs_split = jnp.split(freqs, split_indices, axis=1)
+    self.theta = theta
 
   def __call__(self, hidden_states: jax.Array) -> jax.Array:
     _, num_frames, height, width, _ = hidden_states.shape
     p_t, p_h, p_w = self.patch_size
     ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-    freqs_f = jnp.expand_dims(jnp.expand_dims(self.freqs_split[0][:ppf], axis=1), axis=1)
-    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, self.freqs_split[0].shape[-1]))
+    freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim)
+
+    freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
+    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
 
-    freqs_h = jnp.expand_dims(jnp.expand_dims(self.freqs_split[1][:pph], axis=0), axis=2)
-    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, self.freqs_split[1].shape[-1]))
+    freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
+    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
 
-    freqs_w = jnp.expand_dims(jnp.expand_dims(self.freqs_split[2][:ppw], axis=0), axis=1)
-    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, self.freqs_split[2].shape[-1]))
+    freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1)
+    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))
 
     freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
     freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
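
The memory point of this hunk: WanRotaryPosEmbed no longer pins the full complex frequency table as module state from construction onward; it stores only the scalars (max_seq_len, theta, attention_head_dim) and rebuilds the table inside __call__, where jit can treat it as a temporary of the traced computation. A toy contrast of the two patterns (hypothetical classes, not the repo's modules):

import jax.numpy as jnp

class CachedTable:                        # old pattern: table held for the module's lifetime
  def __init__(self, n: int):
    self.table = jnp.arange(n * 64, dtype=jnp.float32).reshape(n, 64)

  def __call__(self, k: int):
    return self.table[:k]

class LazyTable:                          # new pattern: only scalar config is stored
  def __init__(self, n: int):
    self.n = n

  def __call__(self, k: int):
    # Rebuilt per call; under jax.jit this becomes part of the traced
    # computation rather than long-lived Python-held state.
    table = jnp.arange(self.n * 64, dtype=jnp.float32).reshape(self.n, 64)
    return table[:k]
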
@@ -362,7 +373,7 @@ def __init__(
       qk_norm: Optional[str] = "rms_norm_across_heads",
       eps: float = 1e-6,
       image_dim: Optional[int] = None,
-      added_kn_proj_dim: Optional[int] = None,
+      added_kv_proj_dim: Optional[int] = None,
       rope_max_seq_len: int = 1024,
       pos_embed_seq_len: Optional[int] = None,
       flash_min_seq_length: int = 4096,

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 66 additions & 19 deletions
@@ -18,6 +18,8 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh, PositionalSharding
+import flax
+import flax.linen as nn
 from flax import nnx
 from ...pyconfig import HyperParameters
 from ... import max_logging
@@ -54,6 +56,48 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
   vs.sharding_rules = logical_axis_rules
   return vs
 
+
+@partial(nnx.jit, static_argnums=(3,))
+def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+  # breakpoint()
+  def create_model(rngs: nnx.Rngs, wan_config: dict):
+    wan_transformer = WanModel(**wan_config, rngs=rngs)
+    return wan_transformer
+
+  wan_config = WanModel.load_config(
+      config.pretrained_model_name_or_path,
+      subfolder="transformer"
+  )
+  wan_config["mesh"] = mesh
+  wan_config["dtype"] = config.activations_dtype
+  wan_config["weights_dtype"] = config.weights_dtype
+  wan_config["attention"] = config.attention
+  p_model_factory = partial(create_model, wan_config=wan_config)
+  # Abstract init: shapes and dtypes only, no device buffers are allocated.
+  wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
+  graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
+  # breakpoint()
+  logical_state_spec = nnx.get_partition_spec(state)
+  logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
+  logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
+  params = state.to_pure_dict()
+  state = dict(nnx.to_flat_state(state))
+  # del state
+  params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
+  params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
+  # Place each loaded leaf directly onto its mesh sharding.
+  for path, val in flax.traverse_util.flatten_dict(params).items():
+    sharding = logical_state_sharding[path].value
+    state[path].value = jax.device_put(val, sharding)
+  state = nnx.from_flat_state(state)
+  p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules)
+  state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
+  pspecs = nnx.get_partition_spec(state)
+  # breakpoint()
+  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+  # breakpoint()
+  # wan_transformer = jax.jit(nnx.merge(graphdef, sharded_state, rest_of_state), in_shardings=None, out_shardings=sharded_state)
+  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+  return wan_transformer
+
 @partial(nnx.jit, static_argnums=(1,))
 def create_sharded_logical_model(model, logical_axis_rules):
   graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...)
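
The new function above is where the memory saving comes from: nnx.eval_shape builds the WanModel abstractly (shape/dtype metadata only), the checkpoint is loaded on host CPU, and each leaf is jax.device_put straight onto the sharding derived from the logical axis rules, so the full unsharded parameter tree never has to fit on any single device. A reduced sketch of the same flow in plain JAX (the init function, leaf name, and sharding rule are illustrative assumptions):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

def init_params():
  return {"proj": jnp.zeros((8192, 8192), dtype=jnp.bfloat16)}

# Abstract init: returns ShapeDtypeStructs, allocating nothing on device.
abstract = jax.eval_shape(init_params)
shardings = {"proj": NamedSharding(mesh, P("data", None))}  # assumed rule for this leaf

# Stand-in for the host-side checkpoint load, then per-leaf sharded placement.
loaded = {k: np.zeros(v.shape, v.dtype) for k, v in abstract.items()}
params = {k: jax.device_put(loaded[k], shardings[k]) for k in loaded}
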
@@ -154,26 +198,29 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
 
   @classmethod
   def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-    wan_transformer = WanModel.from_config(
-        config.pretrained_model_name_or_path,
-        subfolder="transformer",
-        rngs=rngs,
-        attention=config.attention,
-        mesh=mesh,
-        dtype=config.activations_dtype,
-        weights_dtype=config.weights_dtype
-    )
-    graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-    params = state.to_pure_dict()
-    del state
-    params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
-    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-    params = jax.device_put(params, PositionalSharding(devices_array).replicate())
-    wan_transformer = nnx.merge(graphdef, params, rest_of_state)
-    # Shard
-    p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
     with mesh:
-      wan_transformer = p_create_sharded_logical_model(model=wan_transformer)
+      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    # wan_transformer = WanModel.from_config(
+    #     config.pretrained_model_name_or_path,
+    #     subfolder="transformer",
+    #     rngs=rngs,
+    #     attention=config.attention,
+    #     mesh=mesh,
+    #     dtype=config.activations_dtype,
+    #     weights_dtype=config.weights_dtype
+    # )
+    # graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
+    # breakpoint()
+    # params = state.to_pure_dict()
+    # del state
+    # #params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
+    # params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
+    # #params = jax.device_put(params, PositionalSharding(devices_array).replicate())
+    # wan_transformer = nnx.merge(graphdef, params, rest_of_state)
+    # # Shard
+    # p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
+    # with mesh:
+    #   wan_transformer = p_create_sharded_logical_model(model=wan_transformer)
     return wan_transformer
 
   @classmethod
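
For contrast, the old path kept in the comments materialized the weights twice: WanModel.from_config initialized concrete parameters eagerly, and PositionalSharding(devices_array).replicate() then put a full copy of the loaded checkpoint on every device before resharding, so per-device HBM scaled with the whole model rather than with one shard. A small illustration of the difference (toy array, assumed 1-D mesh):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
x = np.zeros((len(devices) * 1024, 1024), dtype=np.float32)

replicated = jax.device_put(x, NamedSharding(mesh, P()))           # full copy on every device
sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))  # one slice per device
print(replicated.sharding.shard_shape(x.shape))  # (len(devices) * 1024, 1024)
print(sharded.sharding.shard_shape(x.shape))     # (1024, 1024)
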
