
Commit ed30ace

one transformer inference pass done
1 parent 60ac6c6 commit ed30ace

5 files changed

Lines changed: 146 additions & 54 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml
src/maxdiffusion/generate_ltx_video.py
src/maxdiffusion/max_utils.py
src/maxdiffusion/models/ltx_video/transformers/attention.py
src/maxdiffusion/pyconfig.py

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 9 additions & 3 deletions
```diff
@@ -12,7 +12,7 @@ output_dir: 'ltx-video-output'
 save_config_to_gcs: False
 
 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', ['data','fsdp']],
@@ -25,13 +25,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
+
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
+ici_fsdp_transpose_parallelism: 1
+ici_sequence_parallelism: 1
+ici_tensor_transpose_parallelism: 1
+ici_expert_parallelism: 1
+ici_tensor_sequence_parallelism: 1
 
 
 
@@ -48,4 +54,4 @@ per_device_batch_size: 1
 compile_topology_num_slices: -1
 quantization_local_shard_count: -1
 jit_initializers: True
-enable_single_replica_ckpt_restoring: False
+enable_single_replica_ckpt_restoring: False
```
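For orientation, here is a minimal sketch of how an axis list like the one above turns into a `jax.sharding.Mesh`, and how a `data_sharding` spec that names every axis shards the leading (batch) dimension across all devices. The axis names and per-axis sizes mirror the yml defaults; the device count and helper code are illustrative assumptions, not part of this commit.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Axis names and per-axis parallelism values mirror the yml defaults above.
mesh_axes = ('data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert',
             'tensor_transpose', 'tensor_sequence', 'sequence')
ici_parallelism = (-1, 1, 1, 1, 1, 1, 1, 1)  # -1 absorbs the remaining devices

devices = np.array(jax.devices())
shape = tuple(len(devices) if v == -1 else v for v in ici_parallelism)
mesh = Mesh(devices.reshape(shape), mesh_axes)

# data_sharding lists every mesh axis for dim 0, so a global batch is split
# across all devices along its leading dimension.
batch_sharding = NamedSharding(mesh, P(mesh_axes))
```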

src/maxdiffusion/generate_ltx_video.py

Lines changed: 107 additions & 1 deletion
```diff
@@ -28,6 +28,48 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
   print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
   print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
 
+
+def loop_body(
+    step,
+    args,
+    transformer,
+    fractional_cords,
+    prompt_embeds,
+    segment_ids,
+    encoder_attention_segment_ids,
+):
+  latents, state, noise_cond = args
+  noise_pred = transformer.apply(
+      {"params": state.params},
+      hidden_states=latents,
+      indices_grid=fractional_cords,
+      encoder_hidden_states=prompt_embeds,
+      timestep=noise_cond,
+      segment_ids=segment_ids,
+      encoder_attention_segment_ids=encoder_attention_segment_ids,
+  )
+  return noise_pred, state, noise_cond
+
+
+def run_inference(
+    states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
+):
+  transformer_state = states["transformer"]
+  loop_body_p = functools.partial(
+      loop_body,
+      transformer=transformer,
+      fractional_cords=fractional_cords,
+      prompt_embeds=prompt_embeds,
+      segment_ids=segment_ids,
+      encoder_attention_segment_ids=encoder_attention_segment_ids,
+  )
+  ## TODO: add vae decode step
+  ## TODO: add loop
+  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
+  return latents
+
 def run(config):
   key = jax.random.PRNGKey(0)
```

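A note on the pattern used here: `jax.lax.fori_loop` threads a single carry value through every iteration, so `loop_body` must take `(step, carry)` and return a carry of the same structure; all remaining arguments are frozen with `functools.partial`. A self-contained sketch of just that mechanic (the toy body below stands in for `transformer.apply` and is not from this commit):

```python
import functools

import jax
import jax.numpy as jnp


def loop_body(step, carry, scale):
  # Stand-in for the real body: unpack the carry and return a value of the
  # same structure so fori_loop can thread it into the next iteration.
  latents, noise_cond = carry
  return latents * scale + noise_cond, noise_cond


body = functools.partial(loop_body, scale=0.5)  # freeze the static argument
init = (jnp.ones((4, 256)), jnp.zeros((4, 256)))
latents, _ = jax.lax.fori_loop(0, 1, body, init)  # one iteration, as above
```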
```diff
@@ -50,7 +92,7 @@ def run(config):
 
 
   transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
-  transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only = True)
+  transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only = True)
 
   weights_init_fn = functools.partial(
       transformer.init_weights,
@@ -75,7 +117,61 @@ def run(config):
   )
 
 
+
+
+  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
+  get_memory_allocations()
+
+  states = {}
+  state_shardings = {}
+
+  state_shardings["transformer"] = transformer_state_shardings
+  states["transformer"] = transformer_state
+
+  # create dummy inputs:
+  example_inputs = {}
+  batch_size, num_tokens = 4, 256
+  input_shapes = {
+      "latents": (batch_size, num_tokens, in_channels),
+      "fractional_coords": (batch_size, 3, num_tokens),
+      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
+      "timestep": (batch_size, 256),
+      "segment_ids": (batch_size, 256),
+      "encoder_attention_segment_ids": (batch_size, 128),
+  }
+  for name, shape in input_shapes.items():
+    example_inputs[name] = jnp.ones(
+        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
+    )
+
+  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
+  latents = jax.device_put(example_inputs["latents"], data_sharding)
+  prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
+  fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
+  noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
+  segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
+  encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)
+
+  validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
+  p_run_inference = jax.jit(
+      functools.partial(
+          run_inference,
+          transformer=transformer,
+          config=config,
+          mesh=mesh,
+          latents=latents,
+          fractional_cords=fractional_coords,
+          prompt_embeds=prompt_embeds,
+          timestep=noise_cond,
+          segment_ids=segment_ids,
+          encoder_attention_segment_ids=encoder_attention_segment_ids,
+      ),
+      in_shardings=(state_shardings,),
+      out_shardings=None,
+  )
 
+  noise_pred = p_run_inference(states).block_until_ready()
+  print(noise_pred)  # (4, 256, 128)
 
 
 def main(argv: Sequence[str]) -> None:
@@ -89,4 +185,14 @@ def main(argv: Sequence[str]) -> None:
 
 
 
+
+
+
+
+
+
+
+
+
+
```

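The `jax.jit` wrapping above is a common JAX serving pattern: every static input is closed over with `functools.partial`, the mutable state pytree is the only traced argument, and `in_shardings` pins how that pytree must be laid out on the mesh. A self-contained sketch on a 1-D mesh (the mesh shape, the toy `run_inference`, and all values below are illustrative assumptions, not this commit's code):

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('data',))
state_sharding = NamedSharding(mesh, P())       # params: replicated
data_sharding = NamedSharding(mesh, P('data'))  # activations: batch-sharded


def run_inference(states, latents):
  return latents * states["scale"]  # stand-in for the denoising loop


latents = jax.device_put(jnp.ones((2 * jax.device_count(), 256)), data_sharding)
p_run_inference = jax.jit(
    functools.partial(run_inference, latents=latents),
    in_shardings=({"scale": state_sharding},),  # one entry per traced argument
    out_shardings=None,                         # let the compiler decide
)
out = p_run_inference({"scale": jnp.float32(2.0)}).block_until_ready()
```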
src/maxdiffusion/max_utils.py

Lines changed: 9 additions & 31 deletions
```diff
@@ -251,46 +251,24 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type):
 
   return parallelism_vals
 
-
-def create_device_mesh(config, devices=None, logging=True):
+def create_device_mesh(config, devices=None):
   """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
-  try:
-    num_slices = 1 + max([d.slice_index for d in devices])
-  except:
-    num_slices = 1
+  num_slices = 1
   num_devices_per_slice = num_devices // num_slices
-  max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
-
-  multi_slice_env = num_slices > 1
-
-  dcn_parallelism = [
-      config.dcn_data_parallelism,
-      config.dcn_fsdp_parallelism,
-      config.dcn_tensor_parallelism,
-  ]
-  ici_parallelism = [
-      config.ici_data_parallelism,
-      config.ici_fsdp_parallelism,
-      config.ici_tensor_parallelism,
-  ]
 
   # Find possible unspecified parallelisms
-  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
-  if multi_slice_env:
-    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
-    mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
-  else:
-    mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
-
-  if logging:
-    max_logging.log(f"Decided on mesh: {mesh}")
+  ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
+  mesh = mesh_utils.create_device_mesh(
+      ici_parallelism,
+      devices,
+  )
+  max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
 
   return mesh
 
-
 def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
   """Unboxes the flax.LogicallyPartitioned pieces in a train state.
```
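On the `-1` convention consumed here: judging from the call above, a single unspecified axis absorbs whatever device count the explicitly sized axes leave over. A hedged toy re-implementation of that semantics (the real logic lives in `fill_unspecified_mesh_axes`; this version is an assumption for illustration):

```python
import numpy as np


def fill_unspecified(parallelism_vals, target_product):
  """Resolve a single -1 entry so the product of all entries == target_product."""
  specified = int(np.prod([v for v in parallelism_vals if v != -1]))
  assert target_product % specified == 0, "explicit axes must divide the device count"
  return [target_product // specified if v == -1 else v for v in parallelism_vals]


# With 8 devices and ici_data_parallelism = -1 (the yml default), 'data' gets all 8:
print(fill_unspecified([-1, 1, 1, 1, 1, 1, 1, 1], 8))  # [8, 1, 1, 1, 1, 1, 1, 1]
```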
```diff
@@ -612,4 +590,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
     initialize_jax_for_gpu()
     max_logging.log("Jax distributed system initialized on GPU!")
   else:
-    jax.distributed.initialize()
+    jax.distributed.initialize()
```

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 4 additions & 18 deletions
```diff
@@ -631,27 +631,13 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
     raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
   # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
   # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
-  # qkvo_sharding_spec = jax.sharding.PartitionSpec(
-  #     ("data", "fsdp", "fsdp_transpose", "expert"),
-  #     ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
-  #     None,
-  #     None,
-  # )
-  # qkvo_sharding_spec = jax.sharding.PartitionSpec(
-  #     ("data", "fsdp", "fsdp_transpose", "expert"),
-  #     ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
-  #     None,
-  #     None,
-  # )
   qkvo_sharding_spec = jax.sharding.PartitionSpec(
-      None,
-      None,
+      ("data", "fsdp", "fsdp_transpose", "expert"),
+      ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
       None,
       None,
   )
-  # Based on: ("activation_kv_batch", "activation_length")
-  # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
-  qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
+  qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
   wrapped_flash_attention = shard_map(
       partial_flash_attention,
       mesh=sharding_mesh,
```
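The restored specs mean each device receives its local block of the batch axis (split over the data-like mesh axes) and of the heads axis (split over the tensor-like axes), and runs flash attention on that block alone. A minimal sketch of the `shard_map` wiring on a toy two-axis mesh (the plain-softmax body, shapes, and mesh layout are illustrative stand-ins for the real `partial_flash_attention`):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('data', 'tensor'))
qkvo_spec = P('data', 'tensor', None, None)  # (batch, heads, seq, head_dim)
segment_ids_spec = P('data', None)           # (batch, seq)


def local_attention(q, k, v, q_seg, kv_seg):
  # Per-shard stand-in for flash attention: mask out tokens from other segments.
  scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(q.shape[-1])
  mask = q_seg[:, None, :, None] == kv_seg[:, None, None, :]
  return jnp.einsum('bhqk,bhkd->bhqd',
                    jax.nn.softmax(jnp.where(mask, scores, -1e9), axis=-1), v)


wrapped = shard_map(local_attention, mesh=mesh,
                    in_specs=(qkvo_spec,) * 3 + (segment_ids_spec,) * 2,
                    out_specs=qkvo_spec)

b, h, s, d = 2 * jax.device_count(), 8, 128, 64
q = k = v = jnp.ones((b, h, s, d))
seg = jnp.zeros((b, s), dtype=jnp.int32)
out = wrapped(q, k, v, seg, seg)  # same sharded layout as q
```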
```diff
@@ -910,4 +896,4 @@ def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]):
   # Apply rotary embeddings
   out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
 
-  return out
+  return out
```

src/maxdiffusion/pyconfig.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool:
 config = None
 
 
+def create_parallelisms_list(raw_keys):
+  # Order must match mesh_axes in the config:
+  # ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
+  ici_parallelism = [
+      raw_keys["ici_data_parallelism"],
+      raw_keys["ici_fsdp_parallelism"],
+      raw_keys["ici_tensor_parallelism"],
+      raw_keys["ici_fsdp_transpose_parallelism"],
+      raw_keys["ici_expert_parallelism"],
+      raw_keys["ici_tensor_transpose_parallelism"],
+      raw_keys["ici_tensor_sequence_parallelism"],
+      raw_keys["ici_sequence_parallelism"],
+  ]
+  raw_keys["ici_parallelism"] = ici_parallelism
+  return raw_keys
+
+
 def print_system_information():
   max_logging.log(f"System Information: Jax Version: {jax.__version__}")
   max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
```
```diff
@@ -154,6 +169,7 @@ def user_init(raw_keys):
   raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
   raw_keys["num_slices"] = get_num_slices(raw_keys)
   raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
+  raw_keys = create_parallelisms_list(raw_keys)
 
 
 def get_num_slices(raw_keys):
@@ -204,4 +220,4 @@ def initialize(argv, **kwargs):
 if __name__ == "__main__":
   initialize(sys.argv)
   print(config.steps)
-  r = range(config.steps)
+  r = range(config.steps)
```
