-
Notifications
You must be signed in to change notification settings - Fork 69
LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge] #193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 11 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
3776190
set up files for ltxvid
Serenagu525 13656fb
ltx-video-transformer-setup
Serenagu525 7bed4f9
formatting
Serenagu525 b31a97b
conversion script added
Serenagu525 9a9f5db
conversion script checked
Serenagu525 d1c304d
comments removed
Serenagu525 f93c3bd
Added running instructions
Serenagu525 e0327e5
edited instruction
Serenagu525 c369302
ruff check error fixed
Serenagu525 991a44e
mesh edit
Serenagu525 b0e9bab
key error fix
Serenagu525 2737877
added header
Serenagu525 f6115df
auto script
Serenagu525 8bf24a3
headers
Serenagu525 0f8483e
pulled
Serenagu525 eaa7196
auto script for file downloading
Serenagu525 e805034
Update max_utils.py
Serenagu525 36242d2
changed input format
Serenagu525 8fc3626
Merge branch 'conversion-script' of https://github.com/AI-Hypercomput…
Serenagu525 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # Copyright 2025 Google LLC | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| #hardware | ||
| hardware: 'tpu' | ||
| skip_jax_distributed_system: False | ||
|
|
||
| jax_cache_dir: '' | ||
| weights_dtype: 'bfloat16' | ||
| activations_dtype: 'bfloat16' | ||
|
|
||
|
|
||
| run_name: '' | ||
| output_dir: 'ltx-video-output' | ||
| save_config_to_gcs: False | ||
|
|
||
| #parallelism | ||
| mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] | ||
| logical_axis_rules: [ | ||
| ['batch', 'data'], | ||
| ['activation_batch', ['data','fsdp']], | ||
| ['activation_heads', 'tensor'], | ||
| ['activation_kv', 'tensor'], | ||
| ['mlp','tensor'], | ||
| ['embed','fsdp'], | ||
| ['heads', 'tensor'], | ||
| ['conv_batch', ['data','fsdp']], | ||
| ['out_channels', 'tensor'], | ||
| ['conv_out', 'fsdp'], | ||
| ] | ||
| data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] | ||
| dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded | ||
| dcn_fsdp_parallelism: -1 | ||
| dcn_tensor_parallelism: 1 | ||
| ici_data_parallelism: -1 | ||
| ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded | ||
| ici_tensor_parallelism: 1 | ||
|
|
||
| ici_fsdp_transpose_parallelism: 1 | ||
| ici_sequence_parallelism: 1 | ||
| ici_tensor_transpose_parallelism: 1 | ||
| ici_expert_parallelism: 1 | ||
| ici_sequence_parallelism: 1 | ||
|
|
||
|
|
||
|
|
||
|
|
||
| learning_rate_schedule_steps: -1 | ||
| max_train_steps: 500 #TODO: change this | ||
| pretrained_model_name_or_path: '' | ||
| unet_checkpoint: '' | ||
| dataset_name: 'diffusers/pokemon-gpt4-captions' | ||
| train_split: 'train' | ||
| dataset_type: 'tf' | ||
| cache_latents_text_encoder_outputs: True | ||
| per_device_batch_size: 1 | ||
| compile_topology_num_slices: -1 | ||
| quantization_local_shard_count: -1 | ||
| jit_initializers: True | ||
| enable_single_replica_ckpt_restoring: False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| """ | ||
| Copyright 2025 Google LLC | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| from absl import app | ||
| from typing import Sequence | ||
| import jax | ||
| import json | ||
| from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel | ||
| import os | ||
| import functools | ||
| import jax.numpy as jnp | ||
| from maxdiffusion import pyconfig | ||
| from maxdiffusion.max_utils import ( | ||
| create_device_mesh, | ||
| setup_initial_state, | ||
| get_memory_allocations, | ||
| ) | ||
| from jax.sharding import Mesh | ||
| import orbax.checkpoint as ocp | ||
|
|
||
|
|
||
| def validate_transformer_inputs( | ||
| prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids | ||
| ): | ||
| print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) | ||
| print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) | ||
| print("latents.shape: ", latents.shape, latents.dtype) | ||
| print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) | ||
| print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) | ||
| print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) | ||
| print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) | ||
|
|
||
|
|
||
| def run(config): | ||
|
|
||
| key = jax.random.PRNGKey(42) | ||
|
|
||
| devices_array = create_device_mesh(config) | ||
| mesh = Mesh(devices_array, config.mesh_axes) | ||
|
|
||
| base_dir = os.path.dirname(__file__) | ||
|
|
||
| ##load in model config | ||
| config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") | ||
| with open(config_path, "r") as f: | ||
| model_config = json.load(f) | ||
| ckpt_path = model_config["ckpt_path"] | ||
|
|
||
| ignored_keys = [ | ||
| "_class_name", | ||
| "_diffusers_version", | ||
| "_name_or_path", | ||
| "causal_temporal_positioning", | ||
| "in_channels", | ||
| "ckpt_path", | ||
| ] | ||
| in_channels = model_config["in_channels"] | ||
| for name in ignored_keys: | ||
| if name in model_config: | ||
| del model_config[name] | ||
|
|
||
| transformer = Transformer3DModel( | ||
| **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh | ||
| ) | ||
| transformer_param_shapes = transformer.init_weights( # noqa: F841 | ||
| in_channels, key, model_config["caption_channels"], eval_only=True | ||
| ) | ||
| weights_init_fn = functools.partial( | ||
| transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True | ||
| ) | ||
|
|
||
| checkpoint_manager = ocp.CheckpointManager(ckpt_path) | ||
| transformer_state, transformer_state_shardings = setup_initial_state( | ||
| model=transformer, | ||
| tx=None, | ||
| config=config, | ||
| mesh=mesh, | ||
| weights_init_fn=weights_init_fn, | ||
| checkpoint_manager=checkpoint_manager, | ||
| checkpoint_item=" ", | ||
| model_params=None, | ||
| training=False, | ||
| ) | ||
|
|
||
| transformer_state = jax.device_put(transformer_state, transformer_state_shardings) | ||
| get_memory_allocations() | ||
|
|
||
| states = {} | ||
| state_shardings = {} | ||
|
|
||
| state_shardings["transformer"] = transformer_state_shardings | ||
| states["transformer"] = transformer_state | ||
|
|
||
|
|
||
| def main(argv: Sequence[str]) -> None: | ||
| pyconfig.initialize(argv) | ||
| run(pyconfig.config) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| app.run(main) | ||
|
|
||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar comment as Juan from previous PR, why is checkpoint == " "
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if checkpoint set to None, cannot pass the check "if checkpoint_manager and checkpoint_item:" in max_utils.py. So I set it to empty string to get around this