diff --git a/README.md b/README.md index e3d7ab075..4d776ca21 100644 --- a/README.md +++ b/README.md @@ -177,11 +177,11 @@ To generate images, run the following command: ## LTX-Video - In the folder src/maxdiffusion/models/ltx_video/utils, run: ```bash - python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json + python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../ltxv-13B.json ``` - In the repo folder, run: ```bash - python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json" + python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json" ``` - Other generation parameters can be set in ltx_video.yml file. ## Flux diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/ltxv-13B.json similarity index 100% rename from src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json rename to src/maxdiffusion/models/ltx_video/ltxv-13B.json diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 0ca816f9e..1e0abe698 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -60,12 +60,10 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids): # Note: reference shape annotated for first pass default inference parameters - max_logging.log("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) # (3, 256, 4096) float32 - max_logging.log("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) # (3, 3, 3072) float32 - max_logging.log("latents.shape: ", latents.shape, latents.dtype) # (1, 3072, 128) float 32 - max_logging.log( - "encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype - ) # (3, 256) int32 + max_logging.log(f"prompts_embeds.shape: {prompt_embeds.shape}") # (3, 256, 4096) float32 + max_logging.log(f"fractional_coords.shape: {fractional_coords.shape}") # (3, 3, 3072) float32 + max_logging.log(f"latents.shape: {latents.shape}") # (1, 3072, 128) float 32 + max_logging.log(f"encoder_attention_segment_ids.shape: {encoder_attention_segment_ids.shape}") # (3, 256) int32 class LTXVideoPipeline: