-
Notifications
You must be signed in to change notification settings - Fork 69
LTX-Vid Transformer Inference Step [WIP: Do Not Merge] #197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
3776190
set up files for ltxvid
Serenagu525 13656fb
ltx-video-transformer-setup
Serenagu525 7bed4f9
formatting
Serenagu525 7e098c5
format fixed
Serenagu525 e18128c
transformer step and test
Serenagu525 1c55452
removed diffusers import
Serenagu525 fd4af91
fixed mesh
Serenagu525 5e17a62
changed path
Serenagu525 fc60b27
changed path
Serenagu525 3243535
changed config path
Serenagu525 e873a17
ruff check
Serenagu525 d06dee3
changed back pyconfig
Serenagu525 1ea6590
ruff check
Serenagu525 aa7befd
changed sharding back
Serenagu525 d9a3502
removed testing for now
Serenagu525 a1ad421
Update pyconfig.py
Serenagu525 615174f
Update max_utils.py
Serenagu525 7469c62
Update ltx_video.yml
Serenagu525 6de4424
Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred
Serenagu525 18ec247
Delete src/maxdiffusion/tests/ltx_transformer_step_test.py
Serenagu525 8a043f6
sharding back
Serenagu525 35a3337
added test
Serenagu525 546ecab
ruff fixed
Serenagu525 12a247f
added header
Serenagu525 1062c72
license headers
Serenagu525 535c75e
exclude test
Serenagu525 64b82c9
Update checkpointing_utils.py
Serenagu525 103db8f
Update max_utils.py
Serenagu525 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| # Copyright 2025 Google LLC | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| #hardware | ||
| hardware: 'tpu' | ||
| skip_jax_distributed_system: False | ||
|
|
||
| jax_cache_dir: '' | ||
| weights_dtype: 'bfloat16' | ||
| activations_dtype: 'bfloat16' | ||
|
|
||
|
|
||
| run_name: '' | ||
| output_dir: 'ltx-video-output' | ||
| save_config_to_gcs: False | ||
|
|
||
| #parallelism | ||
| mesh_axes: ['data', 'fsdp', 'tensor'] | ||
| logical_axis_rules: [ | ||
| ['batch', 'data'], | ||
| ['activation_heads', 'fsdp'], | ||
| ['activation_batch', ['data','fsdp']], | ||
| ['activation_kv', 'tensor'], | ||
| ['mlp','tensor'], | ||
| ['embed','fsdp'], | ||
| ['heads', 'tensor'], | ||
| ['norm', 'fsdp'], | ||
| ['conv_batch', ['data','fsdp']], | ||
| ['out_channels', 'tensor'], | ||
| ['conv_out', 'fsdp'], | ||
| ['conv_in', 'fsdp'] | ||
| ] | ||
| data_sharding: [['data', 'fsdp', 'tensor']] | ||
| dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded | ||
| dcn_fsdp_parallelism: -1 | ||
| dcn_tensor_parallelism: 1 | ||
| ici_data_parallelism: 1 | ||
| ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded | ||
| ici_tensor_parallelism: 1 | ||
|
|
||
|
|
||
|
|
||
|
|
||
| learning_rate_schedule_steps: -1 | ||
| max_train_steps: 500 #TODO: change this | ||
| pretrained_model_name_or_path: '' | ||
| unet_checkpoint: '' | ||
| dataset_name: 'diffusers/pokemon-gpt4-captions' | ||
| train_split: 'train' | ||
| dataset_type: 'tf' | ||
| cache_latents_text_encoder_outputs: True | ||
| per_device_batch_size: 1 | ||
| compile_topology_num_slices: -1 | ||
| quantization_local_shard_count: -1 | ||
| jit_initializers: True | ||
| enable_single_replica_ckpt_restoring: False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| from absl import app | ||
| from typing import Sequence | ||
| import jax | ||
| import json | ||
| from flax.linen import partitioning as nn_partitioning | ||
| from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel | ||
| import os | ||
| import functools | ||
| import jax.numpy as jnp | ||
| from maxdiffusion import pyconfig | ||
| from maxdiffusion.max_utils import ( | ||
| create_device_mesh, | ||
| setup_initial_state, | ||
| get_memory_allocations, | ||
| ) | ||
| from jax.sharding import Mesh, PartitionSpec as P | ||
| import orbax.checkpoint as ocp | ||
|
|
||
|
|
||
| def validate_transformer_inputs( | ||
|
Serenagu525 marked this conversation as resolved.
|
||
| prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids | ||
| ): | ||
| print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) | ||
| print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) | ||
| print("latents.shape: ", latents.shape, latents.dtype) | ||
| print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) | ||
| print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) | ||
| print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) | ||
| print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) | ||
|
|
||
|
|
||
| def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): | ||
| latents, state, noise_cond = args | ||
| noise_pred = transformer.apply( | ||
| {"params": state.params}, | ||
| hidden_states=latents, | ||
| indices_grid=fractional_cords, | ||
| encoder_hidden_states=prompt_embeds, | ||
| timestep=noise_cond, | ||
| segment_ids=segment_ids, | ||
| encoder_attention_segment_ids=encoder_attention_segment_ids, | ||
| ) | ||
| return noise_pred, state, noise_cond | ||
|
|
||
|
|
||
| def run_inference( | ||
| states, | ||
| transformer, | ||
| config, | ||
| mesh, | ||
| latents, | ||
| fractional_cords, | ||
| prompt_embeds, | ||
| timestep, | ||
| segment_ids, | ||
| encoder_attention_segment_ids, | ||
| ): | ||
| transformer_state = states["transformer"] | ||
| loop_body_p = functools.partial( | ||
| loop_body, | ||
| transformer=transformer, | ||
| fractional_cords=fractional_cords, | ||
| prompt_embeds=prompt_embeds, | ||
| segment_ids=segment_ids, | ||
| encoder_attention_segment_ids=encoder_attention_segment_ids, | ||
| ) | ||
|
|
||
| with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): | ||
| noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) | ||
| return noise_pred | ||
|
|
||
|
|
||
| def run(config): | ||
| key = jax.random.PRNGKey(42) | ||
|
|
||
| devices_array = create_device_mesh(config) | ||
| mesh = Mesh(devices_array, config.mesh_axes) | ||
|
|
||
| base_dir = os.path.dirname(__file__) | ||
|
|
||
| ##load in model config | ||
| config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") | ||
| with open(config_path, "r") as f: | ||
| model_config = json.load(f) | ||
| relative_ckpt_path = model_config["ckpt_path"] | ||
|
|
||
| ignored_keys = [ | ||
| "_class_name", | ||
| "_diffusers_version", | ||
| "_name_or_path", | ||
| "causal_temporal_positioning", | ||
| "in_channels", | ||
| "ckpt_path", | ||
| ] | ||
| in_channels = model_config["in_channels"] | ||
| for name in ignored_keys: | ||
| if name in model_config: | ||
| del model_config[name] | ||
|
|
||
| transformer = Transformer3DModel( | ||
| **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh | ||
| ) | ||
| transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841 | ||
| weights_init_fn = functools.partial( | ||
| transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True | ||
| ) | ||
|
|
||
| absolute_ckpt_path = os.path.abspath(relative_ckpt_path) | ||
|
|
||
| checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) | ||
| transformer_state, transformer_state_shardings = setup_initial_state( | ||
| model=transformer, | ||
| tx=None, | ||
| config=config, | ||
| mesh=mesh, | ||
| weights_init_fn=weights_init_fn, | ||
| checkpoint_manager=checkpoint_manager, | ||
| checkpoint_item=" ", | ||
| model_params=None, | ||
| training=False, | ||
| ) | ||
|
|
||
| transformer_state = jax.device_put(transformer_state, transformer_state_shardings) | ||
| get_memory_allocations() | ||
|
|
||
| states = {} | ||
| state_shardings = {} | ||
|
|
||
| state_shardings["transformer"] = transformer_state_shardings | ||
| states["transformer"] = transformer_state | ||
|
|
||
| # create dummy inputs: | ||
| example_inputs = {} | ||
| batch_size, num_tokens = 4, 256 | ||
| input_shapes = { | ||
| "latents": (batch_size, num_tokens, in_channels), | ||
| "fractional_coords": (batch_size, 3, num_tokens), | ||
| "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), | ||
| "timestep": (batch_size, 256), | ||
| "segment_ids": (batch_size, 256), | ||
| "encoder_attention_segment_ids": (batch_size, 128), | ||
| } | ||
| for name, shape in input_shapes.items(): | ||
| example_inputs[name] = jnp.ones( | ||
| shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool | ||
| ) | ||
|
|
||
| data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) | ||
| latents = jax.device_put(example_inputs["latents"], data_sharding) | ||
| prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) | ||
| fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) | ||
| noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) | ||
| segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) | ||
| encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) | ||
|
|
||
| validate_transformer_inputs( | ||
| prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids | ||
| ) | ||
| p_run_inference = jax.jit( | ||
| functools.partial( | ||
| run_inference, | ||
| transformer=transformer, | ||
| config=config, | ||
| mesh=mesh, | ||
| latents=latents, | ||
| fractional_cords=fractional_coords, | ||
| prompt_embeds=prompt_embeds, | ||
| timestep=noise_cond, | ||
| segment_ids=segment_ids, | ||
| encoder_attention_segment_ids=encoder_attention_segment_ids, | ||
| ), | ||
| in_shardings=(state_shardings,), | ||
| out_shardings=None, | ||
| ) | ||
|
|
||
| noise_pred = p_run_inference(states).block_until_ready() | ||
| print(noise_pred) # (4, 256, 128) | ||
|
|
||
|
|
||
| def main(argv: Sequence[str]) -> None: | ||
| pyconfig.initialize(argv) | ||
| run(pyconfig.config) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| app.run(main) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
Serenagu525 marked this conversation as resolved.
|
Empty file.
|
Serenagu525 marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from enum import Enum, auto | ||
| from typing import Optional | ||
|
|
||
| import jax | ||
| from flax import linen as nn | ||
|
|
||
| SKIP_GRADIENT_CHECKPOINT_KEY = "skip" | ||
|
|
||
|
|
||
| class GradientCheckpointType(Enum): | ||
| """ | ||
| Defines the type of the gradient checkpoint we will have | ||
|
|
||
| NONE - means no gradient checkpoint | ||
| FULL - means full gradient checkpoint, wherever possible (minimum memory usage) | ||
| MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, | ||
| except for ones that involve batch dimension - that means that all attention and projection | ||
| layers will have gradient checkpoint, but not the backward with respect to the parameters | ||
| """ | ||
|
|
||
| NONE = auto() | ||
| FULL = auto() | ||
| MATMUL_WITHOUT_BATCH = auto() | ||
|
|
||
| @classmethod | ||
| def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": | ||
| """ | ||
| Constructs the gradient checkpoint type from a string | ||
|
|
||
| Args: | ||
| s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. | ||
|
|
||
| Returns: | ||
| GradientCheckpointType: The policy that corresponds to the string | ||
| """ | ||
| if s is None: | ||
| s = "none" | ||
| return GradientCheckpointType[s.upper()] | ||
|
|
||
| def to_jax_policy(self): | ||
| """ | ||
| Converts the gradient checkpoint type to a jax policy | ||
| """ | ||
| match self: | ||
| case GradientCheckpointType.NONE: | ||
| return SKIP_GRADIENT_CHECKPOINT_KEY | ||
| case GradientCheckpointType.FULL: | ||
| return None | ||
| case GradientCheckpointType.MATMUL_WITHOUT_BATCH: | ||
| return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims | ||
|
|
||
| def apply(self, module: nn.Module) -> nn.Module: | ||
| """ | ||
| Applies a gradient checkpoint policy to a module | ||
| if no policy is needed, it will return the module as is | ||
|
|
||
| Args: | ||
| module (nn.Module): the module to apply the policy to | ||
|
|
||
| Returns: | ||
| nn.Module: the module with the policy applied | ||
| """ | ||
| policy = self.to_jax_policy() | ||
| if policy == SKIP_GRADIENT_CHECKPOINT_KEY: | ||
| return module | ||
| return nn.remat( # pylint: disable=invalid-name | ||
| module, | ||
| prevent_cse=False, | ||
| policy=policy, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.