
Commit 9a9f5db: conversion script checked
1 parent b31a97b

9 files changed, 218 additions & 101 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 4 additions & 1 deletion
@@ -213,7 +213,10 @@ def load_state_if_possible(
     max_logging.log(f"restoring from this run's directory latest step {latest_step}")
     try:
       if not enable_single_replica_ckpt_restoring:
-        item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
+        if checkpoint_item == " ":
+          return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
+        else:
+          item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
         return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

     def map_to_pspec(data):
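
The restore path keys off a sentinel: a checkpoint_item of a single space means the checkpoint was saved as one unnamed tree, so it is restored whole with StandardRestore; any other name is looked up as an item inside a composite checkpoint. A minimal sketch of that dispatch, with an illustrative function name (restore_latest is not in the repo):

import orbax.checkpoint as ocp

def restore_latest(checkpoint_manager, abstract_unboxed_pre_state, checkpoint_item):
  # Restore whatever step the manager considers latest.
  latest_step = checkpoint_manager.latest_step()
  if checkpoint_item == " ":
    # Sentinel: the checkpoint is one unnamed tree; restore it whole.
    return checkpoint_manager.restore(
        latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)
    )
  # Otherwise the state lives under a named item of a composite checkpoint.
  item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
  return checkpoint_manager.restore(latest_step, args=ocp.args.Composite(**item))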

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 10 additions & 3 deletions
@@ -27,7 +27,7 @@ output_dir: 'ltx-video-output'
 save_config_to_gcs: False

 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', ['data','fsdp']],
@@ -40,14 +40,20 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1

+ici_fsdp_transpose_parallelism: 1
+ici_sequence_parallelism: 1
+ici_tensor_transpose_parallelism: 1
+ici_expert_parallelism: 1
+ici_tensor_sequence_parallelism: 1
+


@@ -62,4 +68,5 @@ cache_latents_text_encoder_outputs: True
 per_device_batch_size: 1
 compile_topology_num_slices: -1
 quantization_local_shard_count: -1
-jit_initializers: True
+jit_initializers: True
+enable_single_replica_ckpt_restoring: False
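
Every mesh axis needs a matching ICI parallelism value, and the product of the eight values must equal the number of devices per slice; a single -1 is auto-filled by fill_unspecified_mesh_axes in max_utils.py. A simplified, runnable mirror of that fill step (the helper name and the 8-device example are illustrative, not the repo's exact code):

import numpy as np

def fill_unspecified(parallelism_vals, target_product):
  # At most one -1 is expected; it absorbs whatever factor makes the
  # product of all axis values equal the device count.
  vals = list(parallelism_vals)
  if -1 in vals:
    idx = vals.index(-1)
    vals[idx] = 1
    vals[idx] = target_product // int(np.prod(vals))
  assert int(np.prod(vals)) == target_product, "axis values must multiply to the device count"
  return vals

# One value per mesh axis, in mesh_axes order; ici_data_parallelism is -1.
ici = [-1, 1, 1, 1, 1, 1, 1, 1]
print(fill_unspecified(ici, 8))  # on an 8-device slice -> [8, 1, 1, 1, 1, 1, 1, 1]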

src/maxdiffusion/generate_ltx_video.py

Lines changed: 63 additions & 10 deletions
@@ -20,43 +20,90 @@
 import json
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
 import os
+import functools
 import jax.numpy as jnp
 from maxdiffusion import pyconfig
 from maxdiffusion.max_utils import (
     create_device_mesh,
+    setup_initial_state,
+    get_memory_allocations,
 )
+from jax.sharding import Mesh
+import orbax.checkpoint as ocp


-def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
+def validate_transformer_inputs(
+    prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
+):
   print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
   print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
   print("latents.shape: ", latents.shape, latents.dtype)
   print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
+  print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
+  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)


 def run(config):
-  key = jax.random.PRNGKey(0)
+
+  key = jax.random.PRNGKey(42)

   devices_array = create_device_mesh(config)
   mesh = Mesh(devices_array, config.mesh_axes)

-  batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
   base_dir = os.path.dirname(__file__)

   # load in model config
   config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
   with open(config_path, "r") as f:
     model_config = json.load(f)
+  ckpt_path = model_config["ckpt_path"]
+
+  ignored_keys = [
+      "_class_name",
+      "_diffusers_version",
+      "_name_or_path",
+      "causal_temporal_positioning",
+      "in_channels",
+      "ckpt_path",
+  ]
+  in_channels = model_config["in_channels"]
+  for name in ignored_keys:
+    if name in model_config:
+      del model_config[name]
+
+  transformer = Transformer3DModel(
+      **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
+  )
+  transformer_param_shapes = transformer.init_weights(  # noqa: F841
+      in_channels, key, model_config["caption_channels"], eval_only=True
+  )  # use this to test!

-  transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
-  transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False)
+  weights_init_fn = functools.partial(
+      transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True
+  )

-  key, split_key = jax.random.split(key)
+  checkpoint_manager = ocp.CheckpointManager(ckpt_path)
+  transformer_state, transformer_state_shardings = setup_initial_state(
+      model=transformer,
+      tx=None,
+      config=config,
+      mesh=mesh,
+      weights_init_fn=weights_init_fn,
+      checkpoint_manager=checkpoint_manager,
+      checkpoint_item=" ",
+      model_params=None,
+      training=False,
+  )

+  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
+  get_memory_allocations()

-  weights_init_fn = functools.partial(
-      transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True
-  )
+  states = {}
+  state_shardings = {}
+
+  state_shardings["transformer"] = transformer_state_shardings
+  states["transformer"] = transformer_state


 def main(argv: Sequence[str]) -> None:
@@ -66,3 +113,9 @@ def main(argv: Sequence[str]) -> None:

 if __name__ == "__main__":
   app.run(main)
+
+
+# setup_initial_state can optionally load from checkpoint
+
+
+# end to end steps from ltx repo: pipeline_ltx_video.py
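
For reference, the main whose signature appears in the hunk context presumably follows the usual maxdiffusion entrypoint, where pyconfig parses the yml path plus key=value overrides from argv. A sketch under that assumption (run here is a placeholder standing in for the run(config) shown in the diff above):

from typing import Sequence
from absl import app
from maxdiffusion import pyconfig

def run(config):  # placeholder for the run(config) defined in this file
  del config

def main(argv: Sequence[str]) -> None:
  # pyconfig reads configs/ltx_video.yml (given on the command line) and
  # applies any key=value overrides; run() then builds the mesh and
  # restores the transformer state.
  pyconfig.initialize(argv)
  run(pyconfig.config)

if __name__ == "__main__":
  app.run(main)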

src/maxdiffusion/max_utils.py

Lines changed: 74 additions & 28 deletions
@@ -252,45 +252,88 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type
   return parallelism_vals


-def create_device_mesh(config, devices=None, logging=True):
+def create_device_mesh(config, devices=None):
   """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
-  try:
-    num_slices = 1 + max([d.slice_index for d in devices])
-  except:
-    num_slices = 1
+  num_slices = 1
+  # if config.inference_benchmark_test else config.num_slices
   num_devices_per_slice = num_devices // num_slices
-  max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

-  multi_slice_env = num_slices > 1
-
-  dcn_parallelism = [
-      config.dcn_data_parallelism,
-      config.dcn_fsdp_parallelism,
-      config.dcn_tensor_parallelism,
-  ]
-  ici_parallelism = [
-      config.ici_data_parallelism,
-      config.ici_fsdp_parallelism,
-      config.ici_tensor_parallelism,
-  ]
+  # multi_slice_env = num_slices > 1

   # Find possible unspecified parallelisms
-  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
-  if multi_slice_env:
-    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
-    mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
-  else:
-    mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
-
-  if logging:
-    max_logging.log(f"Decided on mesh: {mesh}")
+  ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
+
+  # allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False
+
+  # if allow_split_physical_axes:
+  #   if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
+  #     mesh = mesh_utils.create_device_mesh(
+  #         [16, 16],
+  #         devices,
+  #         contiguous_submeshes=False,
+  #         allow_split_physical_axes=False,
+  #     )
+  #     mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh)
+  #     mesh = np.reshape(mesh, ici_parallelism)
+  #   else:
+  #     mesh = mesh_utils.create_device_mesh(
+  #         ici_parallelism,
+  #         devices,
+  #         contiguous_submeshes=False,
+  #         allow_split_physical_axes=allow_split_physical_axes,
+  #     )
+  # else:
+  mesh = mesh_utils.create_device_mesh(
+      ici_parallelism,
+      devices,
+  )
+  max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")

   return mesh


+# def create_device_mesh(config, devices=None, logging=True):
+#   """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
+#   if devices is None:
+#     devices = jax.devices()
+#   num_devices = len(devices)
+#   try:
+#     num_slices = 1 + max([d.slice_index for d in devices])
+#   except:
+#     num_slices = 1
+#   num_devices_per_slice = num_devices // num_slices
+#   max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
+
+#   multi_slice_env = num_slices > 1
+
+#   dcn_parallelism = [
+#       config.dcn_data_parallelism,
+#       config.dcn_fsdp_parallelism,
+#       config.dcn_tensor_parallelism,
+#   ]
+#   ici_parallelism = [
+#       config.ici_data_parallelism,
+#       config.ici_fsdp_parallelism,
+#       config.ici_tensor_parallelism,
+#   ]
+
+#   # Find possible unspecified parallelisms
+#   ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
+#   if multi_slice_env:
+#     dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
+#     mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
+#   else:
+#     mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
+
+#   if logging:
+#     max_logging.log(f"Decided on mesh: {mesh}")
+
+#   return mesh
+
+
 def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
   """Unboxes the flax.LogicallyPartitioned pieces in a train state.
@@ -402,7 +445,10 @@ def setup_initial_state(
         config.enable_single_replica_ckpt_restoring,
     )
     if state:
-      state = state[checkpoint_item]
+      if checkpoint_item == " ":
+        state = state
+      else:
+        state = state[checkpoint_item]
     if not state:
       max_logging.log(f"Could not find the item in orbax, creating state...")
       init_train_state_partial = functools.partial(
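
The rewritten create_device_mesh always takes the single-slice path: it fills the unspecified axes of config.ici_parallelism and hands the result to mesh_utils.create_device_mesh. A runnable sketch of that path with illustrative axis values (the real code reads them from config; here 'data' absorbs all local devices):

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()
# Eight values, one per axis in ltx_video.yml's mesh_axes.
ici_parallelism = [len(devices), 1, 1, 1, 1, 1, 1, 1]
device_array = mesh_utils.create_device_mesh(ici_parallelism, devices)
mesh = Mesh(device_array, ('data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert',
                           'tensor_transpose', 'tensor_sequence', 'sequence'))
print(f"Num_devices: {len(devices)}, shape {mesh.shape}")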

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 16 additions & 6 deletions
@@ -2,7 +2,6 @@
 import math
 from typing import Any, Dict, Optional, Tuple
 from enum import Enum, auto
-
 import jax
 import jax.nn as jnn
 import jax.numpy as jnp
@@ -198,8 +197,7 @@ def __call__(

     # Adaptive Norm
     if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
-      # [batch, 1 or num_tokens, embedding_dim]
-      assert timestep.ndim == 3
+      assert timestep.ndim == 3  # [batch, 1 or num_tokens, embedding_dim]
       num_ada_params = self.scale_shift_table.shape[0]
       ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape(
           batch_size, timestep.shape[1], num_ada_params, -1
@@ -438,7 +436,7 @@ def __call__(
       deterministic: bool = True,
       **cross_attention_kwargs,
   ) -> jnp.ndarray:
-    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}  # noqa: F821
     assert cross_attention_kwargs.get("scale", None) is None, "Not supported"

     input_axis_names = ("activation_batch", "activation_length", "activation_embed")
@@ -628,8 +626,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
         None,
         None,
     )
+    # qkvo_sharding_spec = jax.sharding.PartitionSpec(
+    #     ("data", "fsdp", "fsdp_transpose", "expert"),
+    #     ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
+    #     None,
+    #     None,
+    # )
+    # qkvo_sharding_spec = jax.sharding.PartitionSpec(
+    #     None,
+    #     None,
+    #     None,
+    #     None,
+    # )
     # Based on: ("activation_kv_batch", "activation_length")
     qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
+    # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
     wrapped_flash_attention = shard_map(
         partial_flash_attention,
         mesh=sharding_mesh,
@@ -814,8 +825,7 @@ def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic:
       inner_dim = dim * self.mult
       if inner_dim < 256:
         raise ValueError("inner_dim must be at least 256")
-      # round to nearest multiple of 256
-      inner_dim = round(inner_dim / 256) * 256
+      inner_dim = round(inner_dim / 256) * 256  # round to nearest multiple of 256
     else:
       inner_dim = self.inner_dim
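
The commented-out PartitionSpecs above are alternative shardings for the shard_map that wraps partial_flash_attention: batch-like dims split over ('data', 'fsdp', 'fsdp_transpose', 'expert') and sequence dims over 'sequence'. A self-contained toy of the same shard_map pattern, shrunk to a 1x1 mesh so it runs on a single device (axis names follow the config; the per-shard body is illustrative):

import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ('data', 'sequence'))
# Mirrors qkv_segment_ids_spec: [batch, length] split over (data, sequence).
segment_ids_spec = PartitionSpec('data', 'sequence')

def per_shard(seg):
  # Each program sees only its local [batch/data, length/sequence] block.
  return (seg == 1).astype(jnp.int32)

f = shard_map(per_shard, mesh=mesh, in_specs=(segment_ids_spec,), out_specs=segment_ids_spec)
print(f(jnp.ones((2, 8), dtype=jnp.int32)))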
