Skip to content

LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge]#193

Closed
Serenagu525 wants to merge 19 commits intomainfrom
conversion-script
Closed

LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge]#193
Serenagu525 wants to merge 19 commits intomainfrom
conversion-script

Conversation

@Serenagu525
Copy link
Copy Markdown
Contributor

@Serenagu525 Serenagu525 commented Jun 27, 2025

Converts ltxv-13b-0.9.7-dev.safetensors from lightricks huggingface into JAX weight checkpoint.
See running instruction at https://github.com/AI-Hypercomputer/maxdiffusion/blob/conversion-script/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md

@github-actions
Copy link
Copy Markdown

try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
if checkpoint_item == " ":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar comment as Juan from previous PR, why is checkpoint == " "

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if checkpoint set to None, cannot pass the check "if checkpoint_manager and checkpoint_item:" in max_utils.py. So I set it to empty string to get around this

axis = _normalize_axes(axis, inputs.ndim)

kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
# kernel_in_axis = np.arange(len(axis))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove commented lines

return t_emb


class AlphaCombinedTimestepSizeEmbeddings(nn.Module):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

complete the docstring?

@Serenagu525 Serenagu525 changed the title LTXVid Transformer Pytorch-Jax Conversion script LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge Aug 5, 2025
@Serenagu525 Serenagu525 changed the title LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge LTXVid Transformer Pytorch-Jax Conversion script [WIP: Do Not Merge] Aug 5, 2025
@Serenagu525 Serenagu525 closed this Aug 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants