Commit 64729b5

formats and paths fixed
1 parent b2ef6ba commit 64729b5

25 files changed: 2458 additions & 4464 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def load_state_if_possible(
       return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
     else:
       item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
-      return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
+      return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

 def map_to_pspec(data):
   pspec = data.sharding.spec
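
Aside: the fixed line uses Orbax's composite-restore pattern. Below is a minimal sketch of that pattern, assuming an ocp.CheckpointManager and an abstract (shapes/dtypes-only) state pytree; it is an illustration, not the repo's exact code.

import orbax.checkpoint as ocp

def restore_named_item(checkpoint_manager, step, abstract_state, checkpoint_item):
  # PyTreeRestore takes the abstract pytree as a template for shapes/shardings;
  # Composite keys it by item name so the manager restores that saved item.
  item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_state)}
  return checkpoint_manager.restore(step, args=ocp.args.Composite(**item))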
Lines changed: 61 additions & 74 deletions
@@ -1,8 +1,6 @@
-from json import encoder
 from absl import app
 from typing import Sequence
 import jax
-from flax import linen as nn
 import json
 from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
@@ -19,7 +17,9 @@
 import orbax.checkpoint as ocp


-def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
+def validate_transformer_inputs(
+    prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
+):
   print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
   print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
   print("latents.shape: ", latents.shape, latents.dtype)
@@ -29,15 +29,7 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise
   print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)


-def loop_body(
-    step,
-    args,
-    transformer,
-    fractional_cords,
-    prompt_embeds,
-    segment_ids,
-    encoder_attention_segment_ids
-):
+def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids):
   latents, state, noise_cond = args
   noise_pred = transformer.apply(
       {"params": state.params},
@@ -46,14 +38,22 @@ def loop_body(
       encoder_hidden_states=prompt_embeds,
       timestep=noise_cond,
       segment_ids=segment_ids,
-      encoder_attention_segment_ids=encoder_attention_segment_ids
+      encoder_attention_segment_ids=encoder_attention_segment_ids,
   )
-  return noise_pred, state, noise_cond
-
+  return noise_pred, state, noise_cond


 def run_inference(
-    states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
+    states,
+    transformer,
+    config,
+    mesh,
+    latents,
+    fractional_cords,
+    prompt_embeds,
+    timestep,
+    segment_ids,
+    encoder_attention_segment_ids,
 ):
   transformer_state = states["transformer"]
   loop_body_p = functools.partial(
loop_body_p = functools.partial(
@@ -62,20 +62,19 @@ def run_inference(
6262
fractional_cords=fractional_cords,
6363
prompt_embeds=prompt_embeds,
6464
segment_ids=segment_ids,
65-
encoder_attention_segment_ids=encoder_attention_segment_ids
65+
encoder_attention_segment_ids=encoder_attention_segment_ids,
6666
)
67-
## TODO: add vae decode step
68-
## TODO: add loop
6967
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
70-
latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
71-
return latents
72-
68+
noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
69+
return noise_pred
70+
71+
7372
def run(config):
74-
key = jax.random.PRNGKey(0)
73+
key = jax.random.PRNGKey(42)
7574

76-
devices_array = create_device_mesh(config)
75+
devices_array = create_device_mesh(config)
7776
mesh = Mesh(devices_array, config.mesh_axes)
78-
77+
7978
base_dir = os.path.dirname(__file__)
8079

8180
##load in model config
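
Aside: run_inference now returns the first element of the fori_loop carry as noise_pred. A self-contained sketch of the jax.lax.fori_loop pattern used here, with functools.partial binding the loop-invariant arguments; the shapes and the toy update rule below are assumed for illustration only.

import functools
import jax
import jax.numpy as jnp

def loop_body(step, args, scale):
  latents, noise_cond = args
  # Stand-in update; the real body calls transformer.apply on the carry.
  latents = latents - scale * noise_cond[..., None]
  return latents, noise_cond

loop_body_p = functools.partial(loop_body, scale=0.1)  # bind static kwargs once
latents = jnp.ones((4, 256, 128))
noise_cond = jnp.ones((4, 256))
latents, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, noise_cond))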
@@ -84,41 +83,42 @@ def run(config):
     model_config = json.load(f)
   relative_ckpt_path = model_config["ckpt_path"]

-  ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
+  ignored_keys = [
+      "_class_name",
+      "_diffusers_version",
+      "_name_or_path",
+      "causal_temporal_positioning",
+      "in_channels",
+      "ckpt_path",
+  ]
   in_channels = model_config["in_channels"]
   for name in ignored_keys:
     if name in model_config:
       del model_config[name]
-
-
-  transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
-  transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only = True)
-
+
+  transformer = Transformer3DModel(
+      **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
+  )
+  transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True)  # noqa F841
   weights_init_fn = functools.partial(
-      transformer.init_weights,
-      in_channels,
-      model_config['caption_channels'],
-      eval_only = True
+      transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True
   )

   absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

   checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
   transformer_state, transformer_state_shardings = setup_initial_state(
-      model=transformer,
-      tx=None,
-      config=config,
-      mesh=mesh,
-      weights_init_fn=weights_init_fn,
-      checkpoint_manager=checkpoint_manager,
-      checkpoint_item=" ",
-      model_params=None,
-      training=False,
+      model=transformer,
+      tx=None,
+      config=config,
+      mesh=mesh,
+      weights_init_fn=weights_init_fn,
+      checkpoint_manager=checkpoint_manager,
+      checkpoint_item=" ",
+      model_params=None,
+      training=False,
   )

-
-
   transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
   get_memory_allocations()
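
Aside: weights_init_fn bundles the model's init_weights with all of its arguments so the state-setup helper can invoke it without knowing them. A minimal stdlib sketch of that functools.partial pattern; the stand-in function and argument values below are assumed, not the repo's.

import functools
import jax

def init_weights(in_channels, rng, caption_channels, eval_only=False):
  # Stand-in for Transformer3DModel.init_weights.
  return {"proj": jax.random.normal(rng, (in_channels, caption_channels))}

key = jax.random.PRNGKey(42)
weights_init_fn = functools.partial(init_weights, 128, key, 4096, eval_only=True)
params = weights_init_fn()  # zero-argument call; everything was bound up front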

@@ -128,20 +128,20 @@ def run(config):
   state_shardings["transformer"] = transformer_state_shardings
   states["transformer"] = transformer_state

-  #create dummy inputs:
+  # create dummy inputs:
   example_inputs = {}
   batch_size, num_tokens = 4, 256
   input_shapes = {
-      "latents": (batch_size, num_tokens, in_channels),
-      "fractional_coords": (batch_size, 3, num_tokens),
-      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
-      "timestep": (batch_size, 256),
-      "segment_ids": (batch_size, 256),
-      "encoder_attention_segment_ids": (batch_size, 128),
+      "latents": (batch_size, num_tokens, in_channels),
+      "fractional_coords": (batch_size, 3, num_tokens),
+      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
+      "timestep": (batch_size, 256),
+      "segment_ids": (batch_size, 256),
+      "encoder_attention_segment_ids": (batch_size, 128),
   }
   for name, shape in input_shapes.items():
     example_inputs[name] = jnp.ones(
-        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
+        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
     )

   data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
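
Aside: the dummy inputs are placed onto the mesh with an explicit NamedSharding, as the diff does via config.data_sharding. A runnable sketch of that device_put pattern, assuming a one-axis "data" mesh over all local devices.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((len(jax.devices()),)), ("data",))
data_sharding = NamedSharding(mesh, P("data"))  # shard the leading (batch) axis

# Batch size chosen divisible by the device count so the shards line up.
batch = jnp.ones((jax.device_count() * 2, 256, 128), dtype=jnp.float32)
batch = jax.device_put(batch, data_sharding)  # distributes shards across devices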
@@ -152,7 +152,9 @@ def run(config):
   segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
   encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

-  validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
+  validate_transformer_inputs(
+      prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
+  )
   p_run_inference = jax.jit(
       functools.partial(
           run_inference,
@@ -162,16 +164,16 @@ def run(config):
           latents=latents,
           fractional_cords=fractional_coords,
           prompt_embeds=prompt_embeds,
-          timestep = noise_cond,
+          timestep=noise_cond,
           segment_ids=segment_ids,
-          encoder_attention_segment_ids=encoder_attention_segment_ids
+          encoder_attention_segment_ids=encoder_attention_segment_ids,
       ),
       in_shardings=(state_shardings,),
       out_shardings=None,
   )

   noise_pred = p_run_inference(states).block_until_ready()
-  print(noise_pred) #(4, 256, 128)
+  print(noise_pred)  # (4, 256, 128)


 def main(argv: Sequence[str]) -> None:
@@ -181,18 +183,3 @@ def main(argv: Sequence[str]) -> None:

 if __name__ == "__main__":
   app.run(main)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
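
Aside: the script jits run_inference with everything except the state pytree pre-bound, then pins the state's layout with in_shardings. A minimal sketch of that pattern, assuming a toy one-entry state dict in place of the real transformer state.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((len(jax.devices()),)), ("data",))
state_shardings = {"transformer": NamedSharding(mesh, P())}  # replicated weights

def run_inference(states, latents):
  return latents * states["transformer"]

p_run_inference = jax.jit(
    functools.partial(run_inference, latents=jnp.ones((4, 256, 128))),
    in_shardings=(state_shardings,),  # one sharding pytree per positional arg
    out_shardings=None,  # let the compiler pick the output layout
)
out = p_run_inference({"transformer": jnp.array(2.0, dtype=jnp.float32)}).block_until_ready()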

src/maxdiffusion/max_utils.py

Lines changed: 47 additions & 9 deletions
@@ -251,24 +251,61 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ

   return parallelism_vals

-def create_device_mesh(config, devices=None):
+
+def create_device_mesh(config, devices=None, logging=True):
   """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
-  num_slices = 1
+  ##special case for ltx-video
+  if config.ici_fsdp_transpose_parallelism:
+    num_slices = 1
+    # if config.inference_benchmark_test else config.num_slices
+    num_devices_per_slice = num_devices // num_slices
+    # Find possible unspecified parallelisms
+    ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
+    mesh = mesh_utils.create_device_mesh(
+        ici_parallelism,
+        devices,
+    )
+    max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
+
+    return mesh
+
+  try:
+    num_slices = 1 + max([d.slice_index for d in devices])
+  except:
+    num_slices = 1
   num_devices_per_slice = num_devices // num_slices
+  max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
+
+  multi_slice_env = num_slices > 1
+
+  dcn_parallelism = [
+      config.dcn_data_parallelism,
+      config.dcn_fsdp_parallelism,
+      config.dcn_tensor_parallelism,
+  ]
+  ici_parallelism = [
+      config.ici_data_parallelism,
+      config.ici_fsdp_parallelism,
+      config.ici_tensor_parallelism,
+  ]

   # Find possible unspecified parallelisms
-  ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
-  mesh = mesh_utils.create_device_mesh(
-      ici_parallelism,
-      devices,
-  )
-  max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
+  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
+  if multi_slice_env:
+    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
+    mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
+  else:
+    mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
+
+  if logging:
+    max_logging.log(f"Decided on mesh: {mesh}")

   return mesh

+
 def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
   """Unboxes the flax.LogicallyPartitioned pieces in a train state.
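
Aside: the restored multi-slice branch distinguishes within-slice (ICI) from between-slice (DCN) parallelism; TPU devices in a multi-slice environment expose slice_index, which is why the diff falls back to one slice when the attribute is absent. A condensed, self-contained sketch of that logic, assuming parallelism lists that multiply out to the device counts.

import jax
from jax.experimental import mesh_utils

def build_mesh(ici_parallelism, dcn_parallelism, devices=None):
  devices = devices if devices is not None else jax.devices()
  try:
    # Only multi-slice TPU deployments define slice_index on devices.
    num_slices = 1 + max(d.slice_index for d in devices)
  except AttributeError:
    num_slices = 1
  if num_slices > 1:
    # Hybrid mesh: DCN (between-slice) axes wrap ICI (within-slice) axes.
    return mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices=devices)
  return mesh_utils.create_device_mesh(ici_parallelism, devices=devices)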
@@ -380,6 +417,7 @@ def setup_initial_state(
         config.enable_single_replica_ckpt_restoring,
     )
     if state:
+      ###!Edited
       if checkpoint_item == " ":
         state = state
       else:
@@ -590,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
       initialize_jax_for_gpu()
       max_logging.log("Jax distributed system initialized on GPU!")
     else:
-      jax.distributed.initialize()
+      jax.distributed.initialize()

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
@@ -1188,4 +1188,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
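
Aside: the touched line is the tail of a GEGLU feed-forward gate. A self-contained sketch of that pattern (module name and sizes are assumed; the repo's proj and dropout_layer are built in setup()):

import jax
import jax.numpy as jnp
import flax.linen as nn

class GEGLU(nn.Module):
  dim_out: int
  dropout: float = 0.0

  @nn.compact
  def __call__(self, hidden_states, deterministic=True):
    # Project to twice the output width, then split into value and gate halves.
    hidden_states = nn.Dense(self.dim_out * 2)(hidden_states)
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
    gated = hidden_linear * nn.gelu(hidden_gelu)
    return nn.Dropout(self.dropout)(gated, deterministic=deterministic)

x = jnp.ones((1, 16, 64))
params = GEGLU(dim_out=128).init(jax.random.PRNGKey(0), x)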
