From 377619088bcd5b04fcab5413b3dfd740c32eb27a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 19:05:46 +0000 Subject: [PATCH 01/69] set up files for ltxvid --- src/maxdiffusion/__init__.py | 721 +++++++++++++------------ src/maxdiffusion/configs/ltx_video.yml | 50 ++ src/maxdiffusion/generate_ltx_video.py | 73 +++ src/maxdiffusion/models/__init__.py | 5 +- 4 files changed, 490 insertions(+), 359 deletions(-) create mode 100644 src/maxdiffusion/configs/ltx_video.yml create mode 100644 src/maxdiffusion/generate_ltx_video.py diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..677d64e4e 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,438 +65,447 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [ + name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [ + name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - 
"KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - 
"StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + 
"WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend( + ["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = [ + "FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = [ + "FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = [ + "Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - 
"FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 
+ else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, + ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, - ) + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..ac333d329 --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,50 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..81c832c3c --- /dev/null +++ 
b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,73 @@ +from absl import app +from typing import Sequence +import jax +import json +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import functools +import jax.numpy as jnp +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, +) +from jax.sharding import Mesh, PartitionSpec as P + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + + +def run(config): + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) + + # load in model config + config_path = os.path.join( + base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + transformer = Transformer3DModel( + **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights( + key, batch_size, text_tokens, num_tokens, features, eval_only=False) + + key, split_key = jax.random.split(key) + weights_init_fn = functools.partial( + transformer.init_weights, + split_key, + batch_size, + text_tokens, + num_tokens, + features, + eval_only=False + ) + + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys From 13656fb457723f52e935e8afab69a983d5cfd68a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 20:32:05 +0000 Subject: [PATCH 02/69] ltx-video-transformer-setup --- src/maxdiffusion/configs/ltx_video.yml | 15 + src/maxdiffusion/generate_ltx_video.py | 29 +- src/maxdiffusion/models/__init__.py | 17 +- src/maxdiffusion/models/ltx_video/__init__.py | 0 .../models/ltx_video/gradient_checkpoint.py | 70 ++ src/maxdiffusion/models/ltx_video/linear.py | 111 ++ .../models/ltx_video/repeatable_layer.py | 105 ++ .../models/ltx_video/transformers/__init__.py | 0 .../ltx_video/transformers/activations.py | 176 ++++ .../models/ltx_video/transformers/adaln.py | 201 ++++ .../ltx_video/transformers/attention.py | 945 ++++++++++++++++++ .../transformers/caption_projection.py | 40 + .../ltx_video/transformers/transformer3d.py | 322 ++++++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 24 + 14 files changed, 2036 insertions(+), 19 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/gradient_checkpoint.py create mode 100644 src/maxdiffusion/models/ltx_video/linear.py create mode 100644 src/maxdiffusion/models/ltx_video/repeatable_layer.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/activations.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/adaln.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/caption_projection.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/transformer3d.py create mode 100644 src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index ac333d329..954922521 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,3 +1,18 @@ +# Copyright 2025 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + #hardware hardware: 'tpu' skip_jax_distributed_system: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 81c832c3c..6d96aa8c2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,20 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + + from absl import app from typing import Sequence import jax @@ -50,17 +67,7 @@ def run(config): text_tokens, num_tokens, features, - eval_only=False - ) - - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, + eval_only=True ) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 96a6f1286..20c27ab20 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,14 +25,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..f32cc9459 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,70 @@ +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = 
auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..fd92c695d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,111 @@ +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] 
= () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features):], + kernel_shape[-len(features):], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..882f21ace --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,105 @@ +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles + + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction 
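+    (for example, the positional constructor arguments shared by every repeated block,
+    such as a transformer block's dimensions — illustrative, the exact module is up to the caller)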
+ """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = { + nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( + self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn( + self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..3e1fd6d6e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,176 @@ +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from diffusers.utils.deprecation_utils import deprecate + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError( + f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..374af6acc --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,201 @@ +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
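+
+    Example:
+        Illustrative call showing only the output shape (the values themselves depend on
+        the sinusoidal schedule described above):
+
+        >>> t = jnp.ones((2, 4))  # timesteps with an arbitrary shape, e.g. (batch, tokens)
+        >>> emb = get_timestep_embedding_multidim(t, embedding_dim=128)
+        >>> emb.shape
+        (2, 4, 128)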
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint( + emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate( + [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint( + sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + + """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
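+        embedding_coefficient (`int`, defaults to 6): Number of modulation chunks produced per
+            embedding; the final linear projection therefore has
+            `embedding_coefficient * embedding_dim` output features.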
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb( + timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..4ade671c7 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,945 @@ +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto + +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", 
"rms_norm"] + assert self.adaptive_norm in [ + "single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name( + hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", + "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * \ + (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm( + hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
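            # Editor's note (not in the original patch): the message above also lists
            # 'group_norm', but setup() only ever assigns None or nn.LayerNorm to
            # self.norm_cross, so the GroupNorm branch in norm_encoder_hidden_states()
            # further down is unreachable as written.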
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any( + device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print( + "Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get( + "scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", + "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint( + hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint( + encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape( + hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states( + encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key 
= self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones( + (batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", + "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes( + hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + \ + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1]( + hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes( + hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape( + skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad( + attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + 
attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), + kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError( + f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
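        Illustrative example (editor's note): with dtype=jnp.bfloat16 the per-dimension
        cap below is 1024, so for q of shape (1, 32, 4096, 128) and k of shape
        (1, 32, 256, 128) this returns block_q = min(1024, 4096) = 1024,
        block_k = block_k_major = min(1024, 256) = 256 and block_b = min(1, 1) = 1;
        with float32 inputs the cap drops to 512.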
+ + TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can + + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM + + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size + + ** SRAM cache size for TPU + V5P - 1MB SRAM per core + + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used + + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) + + +class ExplicitAttention(nn.Module): + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision. 
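    Concretely (matching the implementation below): y = x * rsqrt(mean(x**2, axis=-1) + epsilon),
    optionally multiplied by a learned per-channel scale; the mean of squares is computed in
    float32 and the result is cast back to the input dtype.

    A minimal numerical sketch (editor's example, not part of the original patch):

        import jax
        import jax.numpy as jnp

        x = jnp.array([[3.0, 4.0]])  # mean of squares = (9 + 16) / 2 = 12.5
        y = x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-5)
        # y ~= [[0.8485, 1.1314]], i.e. x rescaled to unit root-mean-square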
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning( + self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype( + jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * \ + jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) + * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError( + f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if 
self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..dff8b8c62 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,40 @@ +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..4368c35fb --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,322 @@ +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 
0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning( + scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm( + epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + 
norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal( + split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal( + split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal( + split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal( + split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection( + encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, + :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + 
scale_shift_values, ("activation_batch", "activation_length", + "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. + dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, + axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..02f13b15a --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,24 @@ +{ + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000 +} \ No newline at end of file From 7bed4f99d1b8c19eb1ee622ee895f1ca31b1b870 Mon Sep 17 00:00:00 2001 
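Editor's note on the 13B configuration above (not part of the patch): with
num_attention_heads=32 and attention_head_dim=128 the transformer's inner_dim is
32 * 128 = 4096, matching caption_channels and cross_attention_dim. In
FreqsCisPrecomputer this gives dim // 6 = 682 rotary frequencies per positional axis
(frames, height, width), i.e. 3 * 682 = 2046 values duplicated to 4092 channels, with
the remaining 4096 % 6 = 4 channels padded with cos=1 / sin=0.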
From: "serenagu@google.com" Date: Thu, 26 Jun 2025 21:12:17 +0000 Subject: [PATCH 03/69] formatting --- src/maxdiffusion/__init__.py | 723 ++++--- src/maxdiffusion/generate_ltx_video.py | 60 +- src/maxdiffusion/models/__init__.py | 17 +- .../models/ltx_video/gradient_checkpoint.py | 102 +- src/maxdiffusion/models/ltx_video/linear.py | 170 +- .../models/ltx_video/repeatable_layer.py | 129 +- .../ltx_video/transformers/activations.py | 275 ++- .../models/ltx_video/transformers/adaln.py | 328 ++-- .../ltx_video/transformers/attention.py | 1696 ++++++++--------- .../transformers/caption_projection.py | 60 +- .../ltx_video/transformers/transformer3d.py | 585 +++--- 11 files changed, 2023 insertions(+), 2122 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 677d64e4e..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,447 +65,440 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [ - name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [ - name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - 
"DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import 
dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - 
"StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + 
"VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - ["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not 
(is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [ - name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = [ - "FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = [ - "FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["models.ltx_video.transformers.transformer3d"] = [ - "Transformer3DModel"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.ltx_video.transformers.transformer3d import Transformer3DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin -else: - import sys - - 
sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 + else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6d96aa8c2..d05203f5c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -14,67 +14,55 @@ limitations under the License. 
""" - from absl import app from typing import Sequence import jax import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os -import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, - setup_initial_state, ) -from jax.sharding import Mesh, PartitionSpec as P def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + # load in model config + config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 - base_dir = os.path.dirname(__file__) + transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) - # load in model config - config_path = os.path.join( - base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) + key, split_key = jax.random.split(key) - transformer = Transformer3DModel( - **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights( - key, batch_size, text_tokens, num_tokens, features, eval_only=False) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( - transformer.init_weights, - split_key, - batch_size, - text_tokens, - num_tokens, - features, - eval_only=True - ) + weights_init_fn = functools.partial( + transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + ) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 20c27ab20..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,15 +25,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import 
Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index f32cc9459..ef8c530ba 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -8,63 +8,63 @@ class GradientCheckpointType(Enum): - """ - Defines the type of the gradient checkpoint we will have + """ + Defines the type of the gradient checkpoint we will have - NONE - means no gradient checkpoint - FULL - means full gradient checkpoint, wherever possible (minimum memory usage) - MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, - except for ones that involve batch dimension - that means that all attention and projection - layers will have gradient checkpoint, but not the backward with respect to the parameters - """ + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ - NONE = auto() - FULL = auto() - MATMUL_WITHOUT_BATCH = auto() + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() - @classmethod - def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": - """ - Constructs the gradient checkpoint type from a string + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string - Args: - s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. 
- Returns: - GradientCheckpointType: The policy that corresponds to the string - """ - if s is None: - s = "none" - return GradientCheckpointType[s.upper()] + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] - def to_jax_policy(self): - """ - Converts the gradient checkpoint type to a jax policy - """ - match self: - case GradientCheckpointType.NONE: - return SKIP_GRADIENT_CHECKPOINT_KEY - case GradientCheckpointType.FULL: - return None - case GradientCheckpointType.MATMUL_WITHOUT_BATCH: - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nn.Module) -> nn.Module: - """ - Applies a gradient checkpoint policy to a module - if no policy is needed, it will return the module as is + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is - Args: - module (nn.Module): the module to apply the policy to + Args: + module (nn.Module): the module to apply the policy to - Returns: - nn.Module: the module with the policy applied - """ - policy = self.to_jax_policy() - if policy == SKIP_GRADIENT_CHECKPOINT_KEY: - return module - return nn.remat( # pylint: disable=invalid-name - module, - prevent_cse=False, - policy=policy, - ) + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index fd92c695d..31b21cdd9 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -13,99 +13,97 @@ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) -NdInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] -KernelInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] class DenseGeneral(nn.Module): - """A linear transformation with flexible axes. - - Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 - - Attributes: - features: tuple with numbers of output features. 
- axis: tuple with axes to apply the transformation on. - weight_dtype: the dtype of the weights (default: float32). - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - use_bias: whether to add bias in linear transformation. - bias_norm: whether to add normalization before adding bias. - quant: quantization config, defaults to None implying no quantization. + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. """ - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - weight_dtype: jnp.dtype = jnp.float32 - dtype: np.dtype = jnp.float32 - kernel_init: KernelInitializer = lecun_normal() - kernel_axes: Tuple[Optional[str], ...] = () - use_bias: bool = False - matmul_precision: str = "default" - - bias_init: Initializer = jax.nn.initializers.constant(0.0) - - @nn.compact - def __call__(self, inputs: jax.Array) -> jax.Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. 
- """ - - def compute_dot_general(inputs, kernel, axis, contract_ind): - """Computes a dot_general operation that may be quantized.""" - dot_general = jax.lax.dot_general - matmul_precision = jax.lax.Precision(self.matmul_precision) - return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) - - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - kernel_in_axis = np.arange(len(axis)) - kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) - kernel = self.param( - "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - ) - kernel = jnp.asarray(kernel, self.dtype) - - contract_ind = tuple(range(0, len(axis))) - output = compute_dot_general(inputs, kernel, axis, contract_ind) - - if self.use_bias: - bias_axes, bias_shape = ( - self.kernel_axes[-len(features):], - kernel_shape[-len(features):], - ) - bias = self.param( - "bias", - nn.with_logical_partitioning(self.bias_init, bias_axes), - bias_shape, - self.weight_dtype, - ) - bias = jnp.asarray(bias, self.dtype) - - output += bias - return output + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + # kernel_in_axis = np.arange(len(axis)) + # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 882f21ace..aaed41048 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -7,99 +7,96 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - @nn.compact - def __call__(self, 
*args) -> Tuple[jax.Array, None]: - """ - jax carry-op format of block - assumes the input contains an input tensor to the block along with kwargs that might be send to the block - kwargs are assumed to have static role, while the input changes between cycles + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles - Returns: - Tuple[jax.Array, None]: Output tensor from the block - """ - mod = self.module(*self.module_init_args, **self.module_init_kwargs) - output = mod(*args) - return output, None + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = { - nn.PARTITION_NAME: self.pspec_name} - - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( - self.param_scan_axis) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, # Separate params per timestep - split_rngs={"params": True}, - in_axes=(nn.broadcast,) * (len(args) - 1), - length=self.num_layers, - **scan_kwargs, - ) - wrapped_function = scan_fn( - self.module, self.module_init_args, self.module_init_kwargs) - x, _ = wrapped_function(*args) - return x + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = 
self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 3e1fd6d6e..4a78b48ea 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -22,155 +22,154 @@ @jax.jit def approximate_gelu(x: jax.Array) -> jax.Array: - """ - Computes Gaussian Error Linear Unit (GELU) activation function + """ + Computes Gaussian Error Linear Unit (GELU) activation function - Args: - x (jax.Array): The input tensor + Args: + x (jax.Array): The input tensor - jax.Array: The output tensor - """ - # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs - # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior - if x.dtype in (jax.numpy.float64,): - x = x.clip(-10, None) - return jax.nn.gelu(x, approximate=True) + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) def get_activation(act_fn: str): - """Returns the activation function from string.""" - act_fn = act_fn.lower() - if act_fn in ACTIVATION_FUNCTIONS: - return ACTIVATION_FUNCTIONS[act_fn] - raise ValueError(f"Unsupported activation function: {act_fn}") + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - approximate: str = "none" - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] 
= () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def gelu(self, gate: jax.Array) -> jax.Array: - approximate_to_tanh = self.approximate == "tanh" - if approximate_to_tanh: - return approximate_gelu(gate) - else: - return jax.nn.gelu(gate, approximate=False) - - @nn.compact - def __call__(self, hidden_states): - if self.approximate not in ("none", "tanh"): - raise ValueError( - f"approximate must be 'none' or 'tanh', got {self.approximate}") - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - hidden_states = proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states class GEGLU(nn.Module): - r""" - A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
- deprecate("scale", "1.0.0", deprecation_message) - - proj = DenseGeneral( - features=self.dim_out * 2, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - - hidden_states = proj(hidden_states) - hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) - return hidden_states * jax.nn.gelu(gate, approximate=False) + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) class ApproximateGELU(nn.Module): - r""" - The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this - [paper](https://arxiv.org/abs/1606.08415). - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, x): - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - x = proj(x) - return x * jax.nn.sigmoid(1.702 * x) + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 374af6acc..4bc27e8bc 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -17,185 +17,177 @@ def get_timestep_embedding_multidim( scale: float = 1, max_period: int = 10000, ) -> jnp.ndarray: - """ - Computes sinusoidal timestep embeddings while preserving the original dimensions. - No reshaping to 1D is performed at any stage. - - Args: - timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. - embedding_dim (int): The dimension of the output. - flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) - or `sin, cos` (if False). - downscale_freq_shift (float): Controls the delta between frequencies between dimensions. - scale (float): Scaling factor applied to the embeddings. - max_period (int): Controls the maximum frequency of the embeddings. - - Returns: - jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. - """ - half_dim = embedding_dim // 2 - exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) - exponent = exponent / (half_dim - downscale_freq_shift) - shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) - emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape - emb = nn.with_logical_constraint( - emb, ("activation_batch", "activation_norm_length", "activation_embed")) - # Broadcasting to match shape (*timesteps.shape, half_dim) - emb = timesteps[..., None] * emb - emb = scale * emb - # Shape (*timesteps.shape, embedding_dim) - emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) - if flip_sin_to_cos: - emb = jnp.concatenate( - [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) - - return emb + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb class TimestepEmbedding(nn.Module): - in_channels: int - time_embed_dim: int - act_fn: str = "silu" - out_dim: Optional[int] = None - sample_proj_bias: bool = True - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers efficiently""" - self.linear_1 = DenseGeneral( - self.time_embed_dim, - use_bias=self.sample_proj_bias, - kernel_axes=(None, "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - ) - - self.act = get_activation(self.act_fn) - time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim - self.linear_2 = DenseGeneral( - time_embed_dim_out, - use_bias=self.sample_proj_bias, - kernel_axes=("embed", "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - ) - - def __call__(self, sample, condition=None): - sample = nn.with_logical_constraint( - sample, ("activation_batch", "activation_norm_length", "activation_embed")) - sample = self.linear_1(sample) - sample = self.act(sample) - sample = self.linear_2(sample) - return sample + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample class Timesteps(nn.Module): - num_channels: int - flip_sin_to_cos: bool - downscale_freq_shift: float - scale: int = 1 - - def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: - t_emb = get_timestep_embedding_multidim( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - return t_emb + 
num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - """ - - """ - - embedding_dim: int - size_emb_dim: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize sub-modules.""" - self.outdim = self.size_emb_dim - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, - time_embed_dim=self.embedding_dim, - name="timestep_embedder", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.astype(hidden_dtype)) - return timesteps_emb + """ """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb class AdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ + Compute AdaLayerNorm-Single modulation. 
- embedding_dim: int - embedding_coefficient: int = 6 - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - self.emb = AlphaCombinedTimestepSizeEmbeddings( - self.embedding_dim, - size_emb_dim=self.embedding_dim // 3, - name="emb", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - self.silu = jax.nn.silu - self.linear = DenseGeneral( - self.embedding_coefficient * self.embedding_dim, - use_bias=True, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear", - ) - - def __call__( - self, - timestep: jnp.ndarray, - added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, - batch_size: Optional[int] = None, - hidden_dtype: Optional[jnp.dtype] = None, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Compute AdaLayerNorm-Single modulation. - - Returns: - Tuple: - - Processed embedding after SiLU + linear transformation. - - Original embedded timestep. - """ - embedded_timestep = self.emb( - timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) - return self.linear(self.silu(embedded_timestep)), embedded_timestep + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4ade671c7..5d12e7813 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -25,921 +25,869 @@ class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - def __call__(self, x): - return x + + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in [ - "single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." 
- - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. 
Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. 
Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), dtype=self.dtype, weight_dtype=self.weight_dtype, + name="to_out.0", matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print( - "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name( - hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] - - # 0. 
Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", - "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) - ) - norm_hidden_states = norm_hidden_states * \ - (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. Self-Attention - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, 
channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if 
self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - attn_output = nn.with_logical_constraint( - attn_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm( - hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint( - attn_input, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. 
Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) +class AttentionOp(nn.Module): - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * \ - (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint( - ff_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
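+    The returned BlockSizes picks the tile sizes (block_q, block_k, ...) for the TPU
+    flash-attention kernel grid, capped at 1024 for bfloat16 inputs and 512 otherwise.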
+ TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
- ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) - - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="to_out.0", - matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any( - device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print( - "Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> jnp.ndarray: - cross_attention_kwargs = { - k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - assert cross_attention_kwargs.get( - "scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", - "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint( - hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint( - encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape( - hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states( - encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key 
= self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") + ** SRAM cache size for TPU + V5P - 1MB SRAM per core - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) - if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones( - (batch_size, query.shape[2]), dtype=jnp.float32) - hidden_states_a = self.attention( - query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype - ) +class ExplicitAttention(nn.Module): - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", - "activation_length", "activation_kv") - ) + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." 
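+    # Reference (non-flash) dot-product attention, used when no TPU is available.
+    # A boolean mask is built from the segment ids below: a query position may attend to a
+    # key/value position only when the two positions share the same segment id.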
+ attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v - hidden_states_a = jnp.reshape(jnp.swapaxes( - hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + \ - hidden_states * (1.0 - skip_layer_mask) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1]( - hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes( - hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape( - skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad( - attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. 
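+  Concretely: y = x * rsqrt(mean(x**2, axis=-1) + epsilon), optionally multiplied by a learned
+  per-feature `scale` parameter when `elementwise_affine` is True.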
+ """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. -class AttentionOp(nn.Module): - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), - kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError( - f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. - qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - None, - None, - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. 
- - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can - - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM - - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size - - ** SRAM cache size for TPU - V5P - 1MB SRAM per core - - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used - - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - ) + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + Returns: + jax.Array: Normed data + """ -class ExplicitAttention(nn.Module): - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." - attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. 
- """ + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. - - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. - - NOTE: if weight is in mixed precision, the operand should be in the same precision. - Args: - hidden_states (jax.Array): Input data - - Returns: - jax.Array: Normed data - """ - - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning( - self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype( - jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * \ - jax.lax.rsqrt(variance + self.epsilon) - - if self.elementwise_affine: - # convert into half-precision if necessary - hidden_states = (hidden_states.astype(self.dtype) - * scale.astype(self.dtype)).astype(input_dtype) - else: - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
- """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError( - f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
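+    inner_dim (`int`, *optional*): Explicit hidden dimension; when omitted it is derived from
+      `dim * mult` and rounded to the nearest multiple of 256.
+    dtype (`jnp.dtype`, defaults to `jnp.float32`): Computation dtype.
+    weight_dtype (`jnp.dtype`, defaults to `jnp.float32`): Parameter dtype.
+    matmul_precision (`str`, defaults to `"default"`): Matmul precision passed to the dense layers.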
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
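+  The last dimension is viewed as interleaved pairs which are rotated as (x1, x2) -> (-x2, x1),
+  and the result is combined as out = input * cos_freqs + rotated_input * sin_freqs.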
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index dff8b8c62..f2b1af101 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -6,35 +6,35 @@ class CaptionProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. - """ + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
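+  Implemented as linear_1 -> approximate GELU -> linear_2, projecting caption embeddings
+  to `hidden_size`.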
+ """ - in_features: int - hidden_size: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" - @nn.compact - def __call__(self, caption): - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - )(caption) - hidden_states = approximate_gelu(hidden_states) - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - )(hidden_states) - return hidden_states + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 4368c35fb..dac8e6280 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -13,310 +13,297 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning( - scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm( - epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal( - 
split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal( - split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal( - split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal( - split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) - if eval_only: - return jax.eval_shape( - self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection( - encoder_hidden_states) - - if self.num_layers > 0: - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - ) - # Output processing - - scale_shift_values = ( - self.scale_shift_table[jnp.newaxis, jnp.newaxis, - :, :] + embedded_timestep[:, :, jnp.newaxis] - ) - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", - "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + 
timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # 
bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. + """ + Computes log of x with defined base. 
- Args: - x (jax.Array): log value - base (jax.Array): base of the log + Args: + x (jax.Array): log value + base (jax.Array): base of the log - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. - """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] - for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, - axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. 
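+    # A note on the frequency construction below: `indices` holds dim // 6 frequencies spaced
+    # geometrically between 1 and theta (then scaled by pi / 2); fractional_positions lie in [0, 1]
+    # per axis (t, h, w) and "* 2 - 1" maps them to [-1, 1]. Each of the 3 axes therefore gets
+    # dim // 6 frequencies, and repeating cos/sin twice along the last axis brings the embedding
+    # up to `dim`. When dim % 6 != 0, the leading channels are padded with cos = 1 / sin = 0,
+    # i.e. an identity rotation for the leftover slots.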
+ dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) From b31a97b4b1ba26c51c26ae142d9ffb356f9d2618 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 26 Jun 2025 22:41:05 +0000 Subject: [PATCH 04/69] conversion script added --- .../utils/convert_torch_weights_to_jax.py | 257 +++++++++ .../utils/diffusers_config_mapping.py | 174 ++++++ .../ltx_video/utils/skip_layer_strategy.py | 8 + .../models/ltx_video/utils/torch_compat.py | 520 ++++++++++++++++++ 4 files changed, 959 insertions(+) create mode 100644 src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_compat.py diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py new file mode 100644 index 000000000..4e306c7ea --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -0,0 +1,257 @@ +import argparse +import json +from typing import Any, Dict, Optional + +import jax +import jax.numpy as jnp +from flax.training import train_state +import optax +import orbax.checkpoint as ocp +from safetensors.torch import load_file + +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel +from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax + +from huggingface_hub import hf_hub_download +import os + + +class Checkpointer: + """ + Checkpointer - to load and store JAX checkpoints + """ + + STATE_DICT_SHAPE_KEY = "shape" + STATE_DICT_DTYPE_KEY = "dtype" + TRAIN_STATE_FILE_NAME = "train_state" + + def __init__( + self, + checkpoint_dir: str, + use_zarr3: bool = False, + save_buffer_size: Optional[int] = None, + restore_buffer_size: Optional[int] = None, + ): + """ + Constructs the checkpointer object + """ + opts = ocp.CheckpointManagerOptions( + enable_async_checkpointing=True, + step_format_fixed_length=8, # to make the format of "00000000" + ) + self.use_zarr3 = use_zarr3 + self.save_buffer_size = save_buffer_size + self.restore_buffer_size = restore_buffer_size + registry = ocp.DefaultCheckpointHandlerRegistry() + self.train_state_handler = ocp.PyTreeCheckpointHandler( + save_concurrent_gb=save_buffer_size, + 
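+            # save_concurrent_gb / restore_concurrent_gb bound how many GiB of checkpoint data
+            # Orbax moves concurrently while saving or restoring (see save_buffer_size_bytes below).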
restore_concurrent_gb=restore_buffer_size, + use_zarr3=use_zarr3, + ) + registry.add( + self.TRAIN_STATE_FILE_NAME, + ocp.args.PyTreeSave, + self.train_state_handler, + ) + self.manager = ocp.CheckpointManager( + directory=checkpoint_dir, + options=opts, + handler_registry=registry, + ) + + @property + def save_buffer_size_bytes(self) -> Optional[int]: + if self.save_buffer_size is None: + return None + return self.save_buffer_size * 2**30 + + @staticmethod + def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts a state dict to a dictionary stating the shape and dtype of the state_dict elements. + With this, we can reconstruct the state_dict structure later on. + """ + return jax.tree_util.tree_map( + lambda t: { + Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape), + Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name, + }, + state_dict, + is_leaf=lambda t: isinstance(t, jax.Array), + ) + + def save( + self, + step: int, + state: train_state.TrainState, + config: Dict[str, Any], + ): + """ + Saves the checkpoint asynchronously + + NOTE that state is going to be copied for this operation + + Args: + step (int): The step of the checkpoint + state (TrainStateWithEma): A trainstate containing both the parameters and the optimizer state + config (Dict[str, Any]): A dictionary containing the configuration of the model + """ + self.wait() + args = ocp.args.Composite( + train_state=ocp.args.PyTreeSave( + state, + ocdbt_target_data_file_size=self.save_buffer_size_bytes, + ), + config=ocp.args.JsonSave(config), + meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)), + ) + self.manager.save( + step, + args=args, + ) + + def wait(self): + """ + Waits for the checkpoint save operation to complete + """ + self.manager.wait_until_finished() + + +""" +Convert Torch checkpoints to JAX. + +This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it. +""" + + +def main(args): + """ + Convert a Torch checkpoint into JAX. + """ + + if args.output_step_num > 1: + print( + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." 
+ ) + + print("Loading safetensors, flush = True") + weight_file = "ltxv-13b-0.9.7-dev.safetensors" + + # download from huggingface, otherwise load from local + if args.local_ckpt_path is None: + print("Loading from HF", flush=True) + model_name = "Lightricks/LTX-Video" + local_file_path = hf_hub_download( + repo_id=model_name, + filename=weight_file, + local_dir=args.download_ckpt_path, + local_dir_use_symlinks=False, + ) + else: + base_dir = args.local_ckpt_path + local_file_path = os.path.join(base_dir, weight_file) + torch_state_dict = load_file(local_file_path) + + print("Initializing pytorch transformer..", flush=True) + transformer_config = json.loads(open(args.transformer_config_path, "r").read()) + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"] + for key in ignored_keys: + if key in transformer_config: + del transformer_config[key] + + transformer = Transformer3DModel_PT.from_config(transformer_config) + + print("Loading torch weights into transformer..", flush=True) + transformer.load_state_dict(torch_state_dict) + torch_state_dict = transformer.state_dict() + + print("Creating jax transformer with params..", flush=True) + transformer_config["use_tpu_flash_attention"] = True + in_channels = transformer_config["in_channels"] + del transformer_config["in_channels"] + jax_transformer3d = JaxTranformer3DModel( + **transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch" + ) + example_inputs = {} + batch_size, num_tokens = 2, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs) + + print("Converting torch params to jax..", flush=True) + params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) + + print("Creating checkpointer and jax state for saving..", flush=True) + relative_ckpt_path = args.output_dir + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + tx = optax.adamw(learning_rate=1e-5) + with jax.default_device("cpu"): + state = train_state.TrainState( + step=args.output_step_num, + apply_fn=jax_transformer3d.apply, + params=params_jax, + tx=tx, + opt_state=tx.init(params_jax), + ) + with ocp.CheckpointManager(absolute_ckpt_path) as mngr: + mngr.save(args.output_step_num, args=ocp.args.StandardSave(state.params)) + print("Done.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") + parser.add_argument( + "--local_ckpt_path", + type=str, + required=False, + help="Local path of the checkpoint to convert. If not provided, will download from huggingface for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'", + ) + + parser.add_argument( + "--download_ckpt_path", + type=str, + required=False, + help="Location to download safetensors from huggingface", + ) + + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to save the checkpoint to. 
for example 'gs://lt-research-mm-europe-west4/jax_trainings/converted-from-torch'", + ) + parser.add_argument( + "--output_step_num", + default=1, + type=int, + required=False, + help=( + "The step number to assign to the output checkpoint. The result will be saved using this step value. " + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." + ), + ) + parser.add_argument( + "--transformer_config_path", + default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json", + type=str, + required=False, + help="Path to Transformer3D structure config to load the weights based on.", + ) + + args = parser.parse_args() + main(args) diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 000000000..832bda051 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 
0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..476d38c75 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py new file mode 100644 index 000000000..475d4b9ef --- /dev/null +++ 
b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -0,0 +1,520 @@ +import re +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union, Any + +import flax +import jax +import torch +import torch.utils._pytree as pytree +from flax.traverse_util import flatten_dict + + +AnyTensor = Union[jax.Array, torch.Tensor] +StateDict = Dict[str, AnyTensor] + +ScanRepeatableCarryBlock = "ScanRepeatableCarryBlock" + +JaxParams = Dict[str, Union[Dict[str, jax.Array], jax.Array]] + + +def unbox_logically_partioned(statedict: JaxParams) -> JaxParams: + return jax.tree_util.tree_map( + lambda t: t.unbox() if isinstance(t, flax.linen.spmd.LogicallyPartitioned) else t, + statedict, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + + +def torch_tensor_to_jax_array(data: torch.Tensor) -> jax.Array: + match data.dtype: + case torch.bfloat16: + return jax.numpy.from_dlpack(data) + case _: + return jax.numpy.array(data) + + +def is_stack_or_tensor(param: Any) -> bool: + """ + Returns True if param is of type tensor or list/tuple of tensors (stack of tensors) + + Used for mapping utils + """ + return isinstance(param, (torch.Tensor, list, tuple)) + + +def convert_tensor_stack_to_tensor(param: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + """ + Converts a list of torch tensors to a single torch tensor. + Args: + param (Union[List[torch.Tensor], torch.Tensor]): The parameter to convert. + + Returns: + torch.Tensor: The converted tensor. + """ + if isinstance(param, list): + return torch.stack(param) + return param + + +@dataclass +class ConvertAction: + """ + Defines a set of actions to be done on a given parameter. + + The definition must be commutative, i.e. the order of the actions should not matter. + also we should strive for actions to be reversible (so the same action can be used for both directions). + """ + + transpose: Optional[Tuple[int, int]] = None + """ + If defined, transposes the tensor with the given indices. + Example: (1, 0) transposes a (at least 2D tensor) from (..., a, b) to (..., b, a). + """ + + rename: Optional[Dict[str, str]] = None + """ + If defined, renames the parameter according to the given mapping. + Example: {"torch": "weight", "jax": "kernel"} + * renames "torch.weight" to "jax.kernel" when converting from torch to jax. + * renames "jax.kernel" to "torch.weight" when converting from jax to torch. + """ + + split_by: Optional[str] = None + """ + If defined, splits the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.k1" assumes the parameter is a concatenation of multiple tensors (shaped: (n, ...)). + and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". + """ + + group_by: Optional[str] = None + """ + If defined, groups the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.1.k1", "ScanRepeatableCarryBlock.2.k1" + will be grouped into a single tensor named "ScanRepeatableCarryBlock.k1" shaped (n, ...). + + *** Note: + this is kind of the reverse of split_by, only a different behavior. + it's easy to define "actions" that are reversible in base of context (jax->torch, torch->jax). + but it's very wrong to do so, since it blocks modular behavior and makes the code harder to maintain. 
+ + """ + + jax_groups: Optional[List[str]] = None + """ + Generally used in group_by, this is a list of all possible keys that can be used to group the parameters. + This must be defined if group_by is defined. + + It's due to the un-reversibility nature of the group_by action. + """ + + def apply_transpose(self, mini_statedict: StateDict) -> StateDict: + """ + Applies the transpose action if defined + Args: + mini_statedict (StateDict): Local context of the state dict + + Returns: + StateDict: Output local context of the state dict + """ + + if self.transpose is None: + return mini_statedict + index0, index1 = self.transpose + return {param_name: param.swapaxes(index0, index1) for param_name, param in mini_statedict.items()} + + def apply_rename(self, mini_statedict: StateDict, delim: str) -> StateDict: + """ + Applies the rename action if defined + + Args: + mini_statedict (StateDict): Local context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + StateDict: Output local context of the state dict + """ + if self.rename is None: + return mini_statedict + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + parts = param_name.split(delim) + rename_source = "torch" if isinstance(param, torch.Tensor) else "jax" + rename_target = "jax" if isinstance(param, torch.Tensor) else "torch" + source_name = self.rename[rename_source] + dest_name = self.rename[rename_target] + if source_name == param_name: + new_param_name = dest_name + else: + # There is always ```self.rename[rename_source]``` in parts + index = parts.index(self.rename[rename_source]) + parts[index] = self.rename[rename_target] + new_param_name = delim.join(parts) + mini_statedict[new_param_name] = param + + return mini_statedict + + def apply_split_by(self, mini_statedict: StateDict, new_params: List, delim: str) -> Tuple[StateDict, List[str]]: + """ + Applies the split_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + delim (str): Output local context of the state dict + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. 
+ """ + if self.split_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + parts = param_name.split(delim) + indices = [i for i, p in enumerate(parts) if self.split_by in p] + if len(indices) != 1: + raise ValueError(f"Expected exactly one split_by in param_name: {param_name}") + index = indices[0] + params = mini_statedict.pop(param_name) + for i, param in enumerate(params): + new_parts = parts[:index] + [f"{i}"] + parts[index + 2 :] + new_param_name = delim.join(new_parts) + mini_statedict[new_param_name] = param + new_params.append(new_param_name) + + return mini_statedict, new_params + + def apply_group_by( + self, mini_statedict: StateDict, new_params: List, full_statedict: StateDict, delim: str + ) -> Tuple[StateDict, List[str]]: + """ + Applies the group_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. + """ + if self.group_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + jax_keywords = extract_scan_keywords(param_name, self.jax_groups, delim) + block_index = re.findall(r"\.\d+\.", param_name)[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + new_param_name = f"{prefix}.{delim.join(jax_keywords)}.{suffix}" + + if new_param_name not in full_statedict: + full_statedict[new_param_name] = [param] + else: + full_statedict[new_param_name] = full_statedict[new_param_name] + [param] + + return mini_statedict, new_params + + def __call__( + self, + mini_statedict: StateDict, + new_params: List, + full_statedict: StateDict, + delim: str, + ) -> Tuple[StateDict, List[str]]: + """ + Given a state dict, applies the transformations defined in the ConvertAction. + + Args: + mini_statedict (StateDict): Local context of the state dict + new_params (List): new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Updated local state dict and list of new keys to add to the global state dict. + """ + mini_statedict = self.apply_transpose(mini_statedict) + mini_statedict = self.apply_rename(mini_statedict, delim) + mini_statedict, new_params = self.apply_split_by(mini_statedict, new_params, delim) + mini_statedict, new_params = self.apply_group_by(mini_statedict, new_params, full_statedict, delim) + return mini_statedict, new_params + + +def is_kernel_2d(param_name: str, param: AnyTensor) -> bool: + """ + Checks if the parameter is a 2D kernel (weight) or not. + usually applies to linear layers or convolutions. + Args: + param_name (str): Name of the parameter + param (AnyTensor): The parameter itself (could be either jax or torch Tensor) + + Returns: + bool: True if the parameter is a weight for linear/convolutional layer or not. 
+ """ + expected_name = "weight" if isinstance(param, torch.Tensor) else "kernel" + return expected_name in param_name and param.ndim == 2 + + +def is_scan_repeatable(param_name: str, _) -> bool: + """ + Checks if the parameter is a scan repeatable carry block parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scan repeatable carry block parameter or not. + """ + return ScanRepeatableCarryBlock in param_name + + +def is_scale_shift_table(param_name: str, _) -> bool: + """ + Checks if the parameter is a scale shift table parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scale shift table parameter or not. + """ + return "scale_shift_table" in param_name + + +def is_affine_scale_param(param_name: str, parameter: AnyTensor, jax_flattened_keys: List[str]) -> bool: + """ + Checks if the parameter is an affine scale parameter. + + Args: + param_name (str): Parameter name + parameter (AnyTensor): The parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + + + Returns: + bool: True if the parameter is an affine scale parameter or not. + """ + if isinstance(parameter, torch.Tensor): + return "weight" in param_name and parameter.ndim == 1 and param_name not in jax_flattened_keys + else: + return "scale" in param_name and parameter.ndim == 1 + + +def extract_scan_keywords(param_name: str, jax_flattened_keys: List[str], delim: str) -> Optional[Tuple[str, str]]: + """ + Extracts the keywords from the scan repeatable carry block parameter (if exists) + + If the parameter is a scan repeatable carry block, it will return the keywords that are used to group the parameters. + otherwise it will return None. + + Args: + param_name (str): Name of the parameter + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + Optional[Tuple[str, str]]: Tuple of the keywords used to group the parameters (or None if it is not a scan repeatable carry block) + """ + block_indices = re.findall(r"\.\d+\.", param_name) + + if len(block_indices) == 0: + return None + block_indices = [block_indices[0]] + block_index = block_indices[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + for flat_key in jax_flattened_keys: + if flat_key.startswith(prefix) and flat_key.endswith(suffix): + mid_layer = flat_key[len(prefix) + 1 : -len(suffix) - 1] + mid_parts = mid_layer.split(delim) + if not any(ScanRepeatableCarryBlock in mid_part for mid_part in mid_parts): + continue + return mid_parts + + return None + + +def should_be_scan_repeatable(param_name: str, param: AnyTensor, jax_flattened_keys: List[str], delim: str) -> bool: + """ + Checks if the parameter should be a scan repeatable carry block or not. + Args: + param_name (str): The name of the parameter + param (AnyTensor): the Parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + bool: True if the paramter should be treated scan repeatable block parameter. 
+ """ + if not isinstance(param, torch.Tensor): + return False + + keywords = extract_scan_keywords(param_name, jax_flattened_keys, delim) + return keywords is not None + + +def jax_statedict_to_torch( + jax_params: JaxParams, rulebook: Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]] = None +) -> Dict[str, torch.Tensor]: + """ + Converts a JAX state dict to a torch state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + rulebook (Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]], optional): Defines a rulebook stating how to convert state dict from jax to torch. + Defaults to None. + + + Returns: + Dict[str, torch.Tensor]: The converted state dict in torch format (Pytorch state dict). + """ + + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=[]) + + if rulebook is None: + rulebook = { + is_scan_repeatable: ConvertAction(split_by=ScanRepeatableCarryBlock), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + } + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + jax_params = copy(jax_params["params"]) # Non reference copy + jax_params = unbox_logically_partioned(jax_params) + + delim = "." + # Move to flattened dict to match torch state dict convention + flattened_params = flatten_dict(jax_params, sep=delim) + + param_names = list(flattened_params.keys()) + for param_name in param_names: + param = flattened_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, flattened_params, delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + flattened_params.update(mini_statedict) + param_names.extend(new_params) + + flattened_params = pytree.tree_map(convert_tensor_stack_to_tensor, flattened_params, is_leaf=is_stack_or_tensor) + + to_cpu = pytree.tree_map(lambda t: jax.device_put(t, jax.devices("cpu")[0]), flattened_params) + to_torch = pytree.tree_map(torch.from_dlpack, to_cpu) + return to_torch + + +def torch_statedict_to_jax( + jax_params: JaxParams, + torch_params: Dict[str, torch.Tensor], +) -> JaxParams: + """ + Converts a torch state dict to a JAX state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + torch_params (Dict[str, torch.Tensor]): The current params in torch format, to load parameters from. + + Returns: + JaxParams: The state dict in JAX format. + """ + with jax.default_device("cpu"): + jax_params = copy(jax_params) + jax_params = unbox_logically_partioned(jax_params) + torch_params = copy(torch_params) + + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + delim = "." 
+ flattened_keys = list(flatten_dict(jax_params["params"], sep=".").keys()) + scan_repeatable_cond = partial(should_be_scan_repeatable, jax_flattened_keys=flattened_keys, delim=delim) + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=flattened_keys) + + rulebook = { + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + scan_repeatable_cond: ConvertAction(group_by=ScanRepeatableCarryBlock, jax_groups=flattened_keys), + } + + # First pass - Rulebook + param_names = list(torch_params.keys()) + for param_name in param_names: + param = torch_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, torch_params, delim=delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + torch_params.update(mini_statedict) + param_names.extend(new_params) + + # Ensures any list of tensors are converted to a single tensor + # This is due to the fact that the scan repeatable block is a list of tensors + torch_params = pytree.tree_map(convert_tensor_stack_to_tensor, torch_params, is_leaf=is_stack_or_tensor) + + to_jax: Dict = pytree.tree_map(torch_tensor_to_jax_array, torch_params) + + def nested_insert(param_name: str, param: torch.Tensor, nested_dict: Dict): + """ + Inserts a parameter into a nested dictionary. (to fit Jax format) + The keys in torch are split into groups by a delimiter of choice (usually "." to fit torch schema) + and then inserted into a nested dictionary. + + in case the parameter is of the form of "a.b" and "a.b" is a layer type in jax - + the parameter will be inserted as "a.b": {...: param}. this ensures compatibility between jax layers and torch layers. 
+ + Args: + param_name (str): Parameter name + param (torch.Tensor): Parameter itself + nested_dict (Dict): Current nested dict state + """ + if delim not in param_name: + nested_dict[param_name] = param + return + + parts = param_name.split(delim) + if len(parts) == 1: + return nested_insert(parts[0], param, nested_dict) + else: + key = parts[0] + # May be either complex key or nested key + if len(parts) > 2 and re.fullmatch(r"\d+", parts[1]) is not None: + key = delim.join(parts[:2]) + new_param_name = delim.join(parts[2:]) + else: + new_param_name = delim.join(parts[1:]) + new_nested_dict = nested_dict.get(key, {}) + nested_dict[key] = new_nested_dict + return nested_insert(new_param_name, param, new_nested_dict) + + params = {} + for param_name, param in to_jax.items(): + nested_insert(param_name, param, params) + + # Jax state dict is usually held as dict containings "parmas" keys which contains + # dict of dict containing all the params + return {"params": params} From 7e098c586fad8874fa0d62912a42a11f159c9545 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 26 Jun 2025 22:56:13 +0000 Subject: [PATCH 05/69] format fixed --- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 10 +-- .../ltx_video/transformers/attention.py | 2 +- .../ltx_video/transformers/transformer3d.py | 68 +++++++------------ .../ltx_video/xora_v1.2-13B-balanced-128.json | 3 +- 5 files changed, 33 insertions(+), 52 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 954922521..eb44d253b 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -62,4 +62,4 @@ cache_latents_text_encoder_outputs: True per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 -jit_initializers: True \ No newline at end of file +jit_initializers: True diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index d05203f5c..6efe564b2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,11 +20,13 @@ import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os +import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, ) +from jax.sharding import Mesh def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): @@ -38,7 +40,7 @@ def run(config): key = jax.random.PRNGKey(0) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) @@ -49,12 +51,10 @@ def run(config): model_config = json.load(f) transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 key, split_key = jax.random.split(key) - - - weights_init_fn = functools.partial( + weights_init_fn = functools.partial( # noqa F841 transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True ) diff --git 
a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 5d12e7813..4812b89ba 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -438,7 +438,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index dac8e6280..cf599f26c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module): only_cross_attention: bool = False double_self_attention: bool = False upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' norm_elementwise_affine: bool = True norm_eps: float = 1e-5 attention_type: str = "default" caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') qk_norm: Optional[str] = None positional_embedding_type: str = "rope" positional_embedding_theta: Optional[float] = None @@ -98,7 +96,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( + module_init_kwargs=dict( # noqa C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, @@ -139,46 +137,30 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + 
example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + if eval_only: return jax.eval_shape( self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, + key, + **example_inputs, )["params"] else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] + return self.init(key, **example_inputs)["params"] def __call__( self, @@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: @nn.compact def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 + dtype = jnp.float32 # We need full precision in the freqs_cis computation. dim = self.inner_dim theta = self.positional_embedding_theta @@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: indices = indices * jnp.pi / 2 freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 cos_freq = jnp.cos(freqs).repeat(2, axis=-1) sin_freq = jnp.sin(freqs).repeat(2, axis=-1) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 02f13b15a..75b16b011 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -20,5 +20,6 @@ "positional_embedding_type": "rope", "positional_embedding_theta": 10000.0, "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000 + "timestep_scale_multiplier": 1000, + "in_channels": 128 } \ No newline at end of file From 9a9f5db89fce77b373b378ec11325f0a49966169 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 00:18:23 +0000 Subject: [PATCH 06/69] conversion script checked --- .../checkpointing/checkpointing_utils.py | 5 +- src/maxdiffusion/configs/ltx_video.yml | 13 ++- src/maxdiffusion/generate_ltx_video.py | 73 +++++++++++-- src/maxdiffusion/max_utils.py | 102 +++++++++++++----- .../ltx_video/transformers/attention.py | 22 ++-- .../ltx_video/transformers/transformer3d.py | 68 +++++------- .../models/ltx_video/utils/torch_compat.py | 16 +-- .../ltx_video/xora_v1.2-13B-balanced-128.json | 4 +- src/maxdiffusion/pyconfig.py | 16 +++ 9 files changed, 218 insertions(+), 101 deletions(-) diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..6661bad83 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,7 
+213,10 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 954922521..034aea65c 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -27,7 +27,7 @@ output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', ['data','fsdp']], @@ -40,7 +40,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 @@ -48,6 +48,12 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 + @@ -62,4 +68,5 @@ cache_latents_text_encoder_outputs: True per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 -jit_initializers: True \ No newline at end of file +jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index d05203f5c..e0b601ee8 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,43 +20,90 @@ import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os +import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, + setup_initial_state, + get_memory_allocations, ) +from jax.sharding import Mesh +import orbax.checkpoint as ocp -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) print("latents.shape: ", latents.shape, latents.dtype) print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) def run(config): - key = jax.random.PRNGKey(0) + + key 
= jax.random.PRNGKey(42) devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) - # load in model config + ##load in model config config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) + ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + transformer_param_shapes = transformer.init_weights( # noqa: F841 + in_channels, key, model_config["caption_channels"], eval_only=True + ) # use this to test! - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) - key, split_key = jax.random.split(key) + checkpoint_manager = ocp.CheckpointManager(ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() - weights_init_fn = functools.partial( - transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True - ) + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state def main(argv: Sequence[str]) -> None: @@ -66,3 +113,9 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": app.run(main) + + +###setup_initial_state, can optionally load from checkpoint + + +# end to end steps from ltx repo: pipeline_ltx_video.py diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..86323bf79 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -252,45 +252,88 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ return parallelism_vals -def create_device_mesh(config, devices=None, logging=True): +def create_device_mesh(config, devices=None): """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: devices = jax.devices() num_devices = len(devices) - try: - num_slices = 1 + max([d.slice_index for d in devices]) - except: - num_slices = 1 + num_slices = 1 + # if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices - max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") - multi_slice_env = num_slices > 1 - - dcn_parallelism = [ - config.dcn_data_parallelism, - config.dcn_fsdp_parallelism, - config.dcn_tensor_parallelism, - ] - ici_parallelism = [ - config.ici_data_parallelism, - config.ici_fsdp_parallelism, - config.ici_tensor_parallelism, - ] + # multi_slice_env = num_slices > 1 # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") - if multi_slice_env: - dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) - else: - mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) - - if logging: - max_logging.log(f"Decided on mesh: {mesh}") + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + + # allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False + + # if allow_split_physical_axes: + # if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh): + # mesh = mesh_utils.create_device_mesh( + # [16, 16], + # devices, + # contiguous_submeshes=False, + # allow_split_physical_axes=False, + # ) + # mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh) + # mesh = np.reshape(mesh, ici_parallelism) + # else: + # mesh = mesh_utils.create_device_mesh( + # ici_parallelism, + # devices, + # contiguous_submeshes=False, + # allow_split_physical_axes=allow_split_physical_axes, + # ) + # else: + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh +# def create_device_mesh(config, devices=None, logging=True): +# """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas""" +# if devices is None: +# devices = jax.devices() +# num_devices = len(devices) +# try: +# num_slices = 1 + max([d.slice_index for d in devices]) +# except: +# num_slices = 1 +# num_devices_per_slice = num_devices // num_slices +# max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") + +# multi_slice_env = num_slices > 1 + +# dcn_parallelism = [ +# config.dcn_data_parallelism, +# config.dcn_fsdp_parallelism, +# config.dcn_tensor_parallelism, +# ] +# ici_parallelism = [ +# config.ici_data_parallelism, +# config.ici_fsdp_parallelism, +# config.ici_tensor_parallelism, +# ] + +# # Find possible unspecified parallelisms +# ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") +# if multi_slice_env: +# dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") +# mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) +# else: +# mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + +# if logging: +# max_logging.log(f"Decided on mesh: {mesh}") + +# return mesh + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
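A minimal, self-contained sketch (not taken from the patch) of what the reworked create_device_mesh above is expected to do on a single slice: it consumes a flat ici_parallelism list, one entry per axis in mesh_axes, where -1 means "absorb the remaining devices" via fill_unspecified_mesh_axes. The axis names below are assumed from ltx_video.yml and the -1 data axis is resolved by hand here:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Axis names assumed from ltx_video.yml; every axis except 'data' is kept at 1 here.
mesh_axes = ("data", "fsdp", "tensor", "fsdp_transpose", "expert",
             "tensor_transpose", "tensor_sequence", "sequence")
num_devices = len(jax.devices())
# ici_data_parallelism is -1 in the config; fill_unspecified_mesh_axes would set it
# to num_devices so the product of all axis sizes matches the device count.
ici_parallelism = [num_devices, 1, 1, 1, 1, 1, 1, 1]
devices_array = mesh_utils.create_device_mesh(ici_parallelism, jax.devices())
mesh = Mesh(devices_array, mesh_axes)
print(mesh.shape)  # e.g. {'data': 4, 'fsdp': 1, ..., 'sequence': 1} on a single v4-8 host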
@@ -402,7 +445,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 5d12e7813..b9185825f 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -2,7 +2,6 @@ import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto - import jax import jax.nn as jnn import jax.numpy as jnp @@ -198,8 +197,7 @@ def __call__( # Adaptive Norm if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] num_ada_params = self.scale_shift_table.shape[0] ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( batch_size, timestep.shape[1], num_ada_params, -1 @@ -438,7 +436,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") @@ -628,8 +626,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): None, None, ) + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # None, + # None, + # None, + # None, + # ) # Based on: ("activation_kv_batch", "activation_length") qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, @@ -814,8 +825,7 @@ def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: inner_dim = dim * self.mult if inner_dim < 256: raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 else: inner_dim = self.inner_dim diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index dac8e6280..9466d1f84 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module): only_cross_attention: bool = False double_self_attention: bool = False upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' standardization_norm: str = "layer_norm" # 'layer_norm' or 
'rms_norm' norm_elementwise_affine: bool = True norm_eps: float = 1e-5 attention_type: str = "default" caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') qk_norm: Optional[str] = None positional_embedding_type: str = "rope" positional_embedding_theta: Optional[float] = None @@ -98,7 +96,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( + module_init_kwargs=dict( #noqa: C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, @@ -139,46 +137,30 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + if eval_only: return jax.eval_shape( self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, + key, + **example_inputs, )["params"] else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] + return self.init(key, **example_inputs)["params"] def __call__( self, @@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: @nn.compact def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
dim = self.inner_dim theta = self.positional_embedding_theta @@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: indices = indices * jnp.pi / 2 freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 cos_freq = jnp.cos(freqs).repeat(2, axis=-1) sin_freq = jnp.sin(freqs).repeat(2, axis=-1) diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py index 475d4b9ef..23fa871d6 100644 --- a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py +++ b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -85,7 +85,7 @@ class ConvertAction: """ If defined, splits the parameter by the given delimiter. Example: "ScanRepeatableCarryBlock.k1" assumes the parameter is a concatenation of multiple tensors (shaped: (n, ...)). - and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". + and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". """ group_by: Optional[str] = None @@ -93,19 +93,19 @@ class ConvertAction: If defined, groups the parameter by the given delimiter. Example: "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.1.k1", "ScanRepeatableCarryBlock.2.k1" will be grouped into a single tensor named "ScanRepeatableCarryBlock.k1" shaped (n, ...). - + *** Note: this is kind of the reverse of split_by, only a different behavior. it's easy to define "actions" that are reversible in base of context (jax->torch, torch->jax). but it's very wrong to do so, since it blocks modular behavior and makes the code harder to maintain. - + """ jax_groups: Optional[List[str]] = None """ Generally used in group_by, this is a list of all possible keys that can be used to group the parameters. This must be defined if group_by is defined. - + It's due to the un-reversibility nature of the group_by action. 
""" @@ -390,8 +390,8 @@ def jax_statedict_to_torch( if rulebook is None: rulebook = { is_scan_repeatable: ConvertAction(split_by=ScanRepeatableCarryBlock), - is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), - affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), # noqa C408 + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), # noqa C408 } if "params" not in jax_params: raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') @@ -452,8 +452,8 @@ def torch_statedict_to_jax( affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=flattened_keys) rulebook = { - is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), - affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), # noqa C408 + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), # noqa C408 scan_repeatable_cond: ConvertAction(group_by=ScanRepeatableCarryBlock, jax_groups=flattened_keys), } diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 02f13b15a..c5b3c0ef9 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,4 +1,5 @@ { + "ckpt_path": "", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, @@ -20,5 +21,6 @@ "positional_embedding_type": "rope", "positional_embedding_theta": 10000.0, "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000 + "timestep_scale_multiplier": 1000, + "in_channels": 128 } \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 67437ba0b..fe4152240 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -154,6 +169,7 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From d1c304da1a6737b7846f17c8f78a82b2fe44895b Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 00:20:19 +0000 Subject: [PATCH 07/69] comments removed --- src/maxdiffusion/generate_ltx_video.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git 
a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index e0b601ee8..d0ad099e9 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -77,8 +77,7 @@ def run(config): ) transformer_param_shapes = transformer.init_weights( # noqa: F841 in_channels, key, model_config["caption_channels"], eval_only=True - ) # use this to test! - + ) weights_init_fn = functools.partial( transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True ) @@ -115,7 +114,4 @@ def main(argv: Sequence[str]) -> None: app.run(main) -###setup_initial_state, can optionally load from checkpoint - -# end to end steps from ltx repo: pipeline_ltx_video.py From f93c3bd5c760db73a3a96601e604432105a49df2 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 00:30:17 +0000 Subject: [PATCH 08/69] Added running instructions --- .../utils/conversion_script_instruction.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md diff --git a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md new file mode 100644 index 000000000..399f64e81 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md @@ -0,0 +1,13 @@ +### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions: +1. Create new tansformers_pytorch folder under models/ltx_video. +2. Move files attention.py, embeddings.py, symmetric_patchifier.py, transformer3d.py into the newly created folder. +3. Rename transformer3d.py to transformer_pt.py to distinguish from the pytorch version. Change classname to Transformer3DModel_PT. Also change classname in line "transformer = Transformer3DModel.from_config(transformer_config)" +4. Weight Downloading and Conversion + - If first time running (no local safetensors): \ + In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json. + - If already have local pytorch checkpoint: \ + Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location +5. Restoring Jax Weights into transformer: + - Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path. + - Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder. + From e0327e5ee32c6744ac3b6f669540fa2fede340ae Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 01:39:38 +0000 Subject: [PATCH 09/69] edited instruction --- .../models/ltx_video/utils/conversion_script_instruction.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md index 399f64e81..3e571f7b7 100644 --- a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md +++ b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md @@ -1,6 +1,6 @@ ### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions: 1. Create new tansformers_pytorch folder under models/ltx_video. -2. 
Move files attention.py, embeddings.py, symmetric_patchifier.py, transformer3d.py into the newly created folder. +2. Move files from LTX repo, specifically, attention.py, embeddings.py, symmetric_patchifier.py, and transformer3d.py into the newly created folder. See here: https://github.com/Lightricks/LTX-Video/tree/main/ltx_video/models/transformers 3. Rename transformer3d.py to transformer_pt.py to distinguish from the pytorch version. Change classname to Transformer3DModel_PT. Also change classname in line "transformer = Transformer3DModel.from_config(transformer_config)" 4. Weight Downloading and Conversion - If first time running (no local safetensors): \ From c3693027792dac97255d842236a58a3794c1a529 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 19:17:13 +0000 Subject: [PATCH 10/69] ruff check error fixed --- src/maxdiffusion/generate_ltx_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index d0ad099e9..9ed2c92f7 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -77,7 +77,7 @@ def run(config): ) transformer_param_shapes = transformer.init_weights( # noqa: F841 in_channels, key, model_config["caption_channels"], eval_only=True - ) + ) weights_init_fn = functools.partial( transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True ) From 991a44ec1c1f7cda209fabebee7db3f4c241ccd2 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 20:19:59 +0000 Subject: [PATCH 11/69] mesh edit --- src/maxdiffusion/max_utils.py | 168 ++++++++++++++++++++-------------- src/maxdiffusion/pyconfig.py | 3 +- 2 files changed, 101 insertions(+), 70 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 86323bf79..6ad77df92 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -252,86 +252,115 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ return parallelism_vals -def create_device_mesh(config, devices=None): +def create_device_mesh(config, devices=None, logging=True): """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: devices = jax.devices() num_devices = len(devices) - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices + ##special case for ltx-video + if config.ici_fsdp_transpose_parallelism: + num_slices = 1 + # if config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + + try: + num_slices = 1 + max([d.slice_index for d in devices]) + except: + num_slices = 1 num_devices_per_slice = num_devices // num_slices + max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") + + multi_slice_env = num_slices > 1 - # multi_slice_env = num_slices > 1 + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_tensor_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_tensor_parallelism, + ] # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - - # allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False - - # if allow_split_physical_axes: - # if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh): - # mesh = mesh_utils.create_device_mesh( - # [16, 16], - # devices, - # contiguous_submeshes=False, - # allow_split_physical_axes=False, - # ) - # mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh) - # mesh = np.reshape(mesh, ici_parallelism) - # else: - # mesh = mesh_utils.create_device_mesh( - # ici_parallelism, - # devices, - # contiguous_submeshes=False, - # allow_split_physical_axes=allow_split_physical_axes, - # ) - # else: - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") + if multi_slice_env: + dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") + mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + else: + mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + + if logging: + max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh -# def create_device_mesh(config, devices=None, logging=True): -# """Creates a device mesh with each slice in its own data parallel group. 
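A note on the " " sentinel used as checkpoint_item in this series: the branch added to checkpointing_utils.py earlier, and the matching branch in setup_initial_state in the next hunk, hand the whole abstract state to Orbax's StandardRestore instead of wrapping it in a named Composite item. A reduced, hedged sketch of that restore call (the helper name and directory argument below are illustrative only):

import orbax.checkpoint as ocp

def restore_full_state(ckpt_dir, abstract_state):
  """Restore the latest step as a single pytree, mirroring the checkpoint_item == " " branch."""
  manager = ocp.CheckpointManager(ckpt_dir)
  latest_step = manager.latest_step()
  if latest_step is None:
    return None  # nothing saved yet; the caller falls back to weights_init_fn
  return manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_state))

# abstract_state is typically the jax.eval_shape output of the model's init, e.g. the
# tree returned by transformer.init_weights(..., eval_only=True) in the generate script.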
If there is only one slice, uses two replicas""" -# if devices is None: -# devices = jax.devices() -# num_devices = len(devices) -# try: -# num_slices = 1 + max([d.slice_index for d in devices]) -# except: -# num_slices = 1 -# num_devices_per_slice = num_devices // num_slices -# max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") - -# multi_slice_env = num_slices > 1 - -# dcn_parallelism = [ -# config.dcn_data_parallelism, -# config.dcn_fsdp_parallelism, -# config.dcn_tensor_parallelism, -# ] -# ici_parallelism = [ -# config.ici_data_parallelism, -# config.ici_fsdp_parallelism, -# config.ici_tensor_parallelism, -# ] - -# # Find possible unspecified parallelisms -# ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") -# if multi_slice_env: -# dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") -# mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) -# else: -# mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) - -# if logging: -# max_logging.log(f"Decided on mesh: {mesh}") - -# return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): @@ -445,6 +474,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: + ###!Edited if checkpoint_item == " ": state = state else: @@ -655,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index fe4152240..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -169,7 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - raw_keys = create_parallelisms_list(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From b0e9bab09045d89a0196ee44a07c05387d2524b8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 27 Jun 2025 20:26:17 +0000 Subject: [PATCH 12/69] key error fix --- src/maxdiffusion/max_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 6ad77df92..c48b7da0f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True): devices = jax.devices() num_devices = len(devices) ##special case for ltx-video - if config.ici_fsdp_transpose_parallelism: + if "fsdp_transpose" in config.mesh_axes: num_slices = 1 # if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices From e18128c3f19db8a8cba15195397281addaf0558a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 18:17:48 +0000 Subject: [PATCH 13/69] transformer step and test --- src/maxdiffusion/configs/ltx_video.yml | 24 ++- src/maxdiffusion/generate_ltx_video.py | 171 ++++++++++++--- src/maxdiffusion/max_utils.py | 21 +- .../ltx_video/xora_v1.2-13B-balanced-128.json | 
1 + src/maxdiffusion/pyconfig.py | 17 ++ .../tests/ltx_transformer_step_test.py | 198 ++++++++++++++++++ .../tests/ltx_vid_transformer_test_ref_pred | Bin 0 -> 263834 bytes 7 files changed, 402 insertions(+), 30 deletions(-) create mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py create mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index eb44d253b..8f1ee8a7d 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,12 +22,25 @@ weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', ['data','fsdp']], @@ -40,13 +53,19 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 + ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 @@ -63,3 +82,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6efe564b2..bad791cee 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,23 +1,8 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-""" - from absl import app from typing import Sequence import jax import json +from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os import functools @@ -25,39 +10,171 @@ from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, + setup_initial_state, + get_memory_allocations, ) -from jax.sharding import Mesh +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) print("latents.shape: ", latents.shape, latents.dtype) print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return noise_pred def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 + mesh = Mesh(devices_array, config.mesh_axes) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) - # load in model config + ##load in model config config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa 
F841 + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( # noqa F841 - transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + + # create dummy inputs: + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids + ) + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, ) + noise_pred = p_run_inference(states).block_until_ready() + print(noise_pred) # (4, 256, 128) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..f3f5148b2 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if config.ici_fsdp_transpose_parallelism: + num_slices = 1 + # if 
config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -402,7 +417,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 75b16b011..c5b3c0ef9 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,4 +1,5 @@ { + "ckpt_path": "", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 67437ba0b..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -154,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py new file mode 100644 index 000000000..b0a266b70 --- /dev/null +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -0,0 +1,198 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + 
training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 0000000000000000000000000000000000000000..0a9fe912036cf35e35d5d8127cff178e1b4a9399 GIT binary patch literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! 
zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C literal 0 HcmV?d00001 From 1c554523fbeddc9d3f85f59ed2e23c36a32356be Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 21:24:26 +0000 Subject: [PATCH 14/69] removed diffusers import --- .../models/ltx_video/transformers/activations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 4a78b48ea..8e7ffb321 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -5,8 +5,6 @@ 
from flax import linen as nn from flax.linen.initializers import lecun_normal -from diffusers.utils.deprecation_utils import deprecate - from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer @@ -117,9 +115,9 @@ class GEGLU(nn.Module): @nn.compact def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) proj = DenseGeneral( features=self.dim_out * 2, From fd4af91ddb23b73dd7fb58958f49fb24bd938db2 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 21:30:45 +0000 Subject: [PATCH 15/69] fixed mesh --- src/maxdiffusion/max_utils.py | 63 +++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index f3f5148b2..c48b7da0f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True): devices = jax.devices() num_devices = len(devices) ##special case for ltx-video - if config.ici_fsdp_transpose_parallelism: + if "fsdp_transpose" in config.mesh_axes: num_slices = 1 # if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
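
# Illustrative sketch, not taken from this patch series: roughly what the
# ltx-video special case above does when it builds a single-slice device mesh,
# assuming a simple three-axis layout and that exactly one ICI factor is left
# as -1 to be inferred from the device count. The helper name
# build_single_slice_mesh is hypothetical, not a repo API.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh


def build_single_slice_mesh(mesh_axes, ici_parallelism):
  devices = jax.devices()
  ici = list(ici_parallelism)
  if -1 in ici:
    # Fill the single unspecified axis so the product matches the device count.
    known = 1
    for factor in ici:
      if factor != -1:
        known *= factor
    ici[ici.index(-1)] = len(devices) // known
  device_array = mesh_utils.create_device_mesh(ici, devices)
  return Mesh(device_array, mesh_axes)


# e.g. build_single_slice_mesh(("data", "fsdp", "tensor"), (-1, 1, 1))
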
@@ -628,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file From 5e17a62648b1d5d26198b3b1adea6cb89e0eb904 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 01:05:14 +0000 Subject: [PATCH 16/69] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index b0a266b70..d0f6c2e1d 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,7 @@ def load_ref_prediction(): - saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From fc60b27b30e2988a4e4fad2d0252a0f044ded270 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 03:44:05 +0000 Subject: [PATCH 17/69] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d0f6c2e1d..d40c932ba 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,8 @@ def load_ref_prediction(): - saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From 3243535216e0348c36ddd7c3e90316c66343026d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:33:16 +0000 Subject: [PATCH 18/69] changed config path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d40c932ba..43b15ba72 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -103,7 +103,9 @@ def test_one_step_transformer(self): key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From e873a17c795037318b46222c86137bb7899bb43d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:44:51 +0000 Subject: [PATCH 19/69] ruff check --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 43b15ba72..61c6909c0 100644 --- 
a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -105,7 +105,7 @@ def test_one_step_transformer(self): mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From d06dee301231259fa83f3e7849d4af0f8655d8bf Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 17:55:48 +0000 Subject: [PATCH 20/69] changed back pyconfig --- src/maxdiffusion/max_utils.py | 78 +---------------------------------- src/maxdiffusion/pyconfig.py | 41 ++++++++++-------- 2 files changed, 24 insertions(+), 95 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +288,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. @@ -474,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..f4e1900aa 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): @@ -221,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) + r = range(config.steps) \ No newline at end of file From aa7befd137cbb9bd28db098d23af7ab75de37b5d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 21:38:05 +0000 Subject: [PATCH 21/69] changed sharding back --- src/maxdiffusion/configs/ltx_video.yml | 8 +++++--- src/maxdiffusion/max_utils.py | 2 +- .../models/ltx_video/transformers/attention.py | 13 ++++++++++--- src/maxdiffusion/pyconfig.py | 2 +- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 8f1ee8a7d..d29707537 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -40,20 +40,22 @@ output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..fab895f97 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -609,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4812b89ba..e9d9d932d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -622,14 +622,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") # Computation of the spec based on the logical constraints can be found in 
logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + ("data", "fsdp", "tensor"), None, None, ) # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("fsdp", None) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index f4e1900aa..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -226,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) \ No newline at end of file + r = range(config.steps) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 61c6909c0..9a816d6e5 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -104,7 +104,7 @@ def test_one_step_transformer(self): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) From d9a35020c13df74bd91ac3e74a111d8a1e82673d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sat, 5 Jul 2025 23:07:45 +0000 Subject: [PATCH 22/69] removed testing for now --- src/maxdiffusion/max_utils.py | 61 +----- .../tests/ltx_transformer_step_test.py | 201 ------------------ .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 3 files changed, 2 insertions(+), 260 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..d4a80a347 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
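
# Illustrative sketch, not taken from this patch series: the shard_map pattern
# that the attention.py hunk above configures, assuming the reduced
# ('data', 'fsdp', 'tensor') mesh and the qkvo spec
# P(None, ('data', 'fsdp', 'tensor'), None, None), i.e. batch replicated and
# the heads axis split across every mesh axis. Plain dot-product attention
# stands in here for the flash-attention kernel.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def attention_kernel(q, k, v):
  # Per-device block: (batch, heads_shard, seq, head_dim).
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


mesh = Mesh(mesh_utils.create_device_mesh((1, 1, jax.device_count())),
            ("data", "fsdp", "tensor"))
qkvo_spec = P(None, ("data", "fsdp", "tensor"), None, None)
sharded_attention = shard_map(
    attention_kernel,
    mesh=mesh,
    in_specs=(qkvo_spec, qkvo_spec, qkvo_spec),
    out_specs=qkvo_spec,
    check_rep=False,
)

# heads (dim 1) must be divisible by the total device count for this spec.
q = k = v = jnp.ones((2, 8, 256, 64))
out = jax.jit(sharded_attention)(q, k, v)
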
@@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 61c6909c0..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = 
Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 
0 HcmV?d00001 literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C From a1ad421eda51fa42d34bc4667ff32f107d9d36e8 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 
5 Jul 2025 16:15:25 -0700 Subject: [PATCH 23/69] Update pyconfig.py --- src/maxdiffusion/pyconfig.py | 39 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index edcf96164..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,7 +25,6 @@ import yaml from . import max_logging from . import max_utils -from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -42,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -103,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) - _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -112,26 +125,6 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict - @staticmethod - def wan_init(raw_keys): - if "wan_transformer_pretrained_model_name_or_path" in raw_keys: - transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] - if transformer_pretrained_model_name_or_path == "": - raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif ( - transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH - or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH - ): - # Set correct parameters for CausVid in case of user error. - raw_keys["guidance_scale"] = 1.0 - num_inference_steps = raw_keys["num_inference_steps"] - if num_inference_steps > 10: - max_logging.log( - f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
- ) - else: - raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") - @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -176,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From 615174f94f66ebd5e19c98d1cac2b0bf242b70de Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:15:57 -0700 Subject: [PATCH 24/69] Update max_utils.py --- src/maxdiffusion/max_utils.py | 78 ++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..51de312ca 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if "fsdp_transpose" in config.mesh_axes: + num_slices = 1 + # if config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -288,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
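
# Illustrative sketch, not taken from this patch series: how the
# logical_axis_rules lists in ltx_video.yml are consumed, assuming Flax's linen
# partitioning helpers. Each rule maps a logical axis name used inside the
# model to a mesh axis, and logical_to_mesh_axes resolves a logical shape to a
# jax.sharding.PartitionSpec.
from flax.linen import partitioning as nn_partitioning

rules = (("batch", "data"), ("embed", "fsdp"), ("heads", "tensor"))
spec = nn_partitioning.logical_to_mesh_axes(("batch", "embed", "heads"), rules)
# expected: PartitionSpec('data', 'fsdp', 'tensor'); the same resolution happens
# implicitly under `with mesh, nn_partitioning.axis_rules(config.logical_axis_rules)`
# as used in run_inference above.
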
@@ -402,7 +474,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( From 7469c62c97f9f711114111490f9220a433b18d8a Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:26 -0700 Subject: [PATCH 25/69] Update ltx_video.yml --- src/maxdiffusion/configs/ltx_video.yml | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index d29707537..87d0e9bea 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,47 +22,32 @@ weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' -run_name: '' -output_dir: 'ltx-video-output' -save_config_to_gcs: False - -#hardware -hardware: 'tpu' -skip_jax_distributed_system: False - -jax_cache_dir: '' -weights_dtype: 'bfloat16' -activations_dtype: 'bfloat16' - - run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 - ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 + ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_tensor_transpose_parallelism: 1 @@ -84,4 +69,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True -enable_single_replica_ckpt_restoring: False \ No newline at end of file +enable_single_replica_ckpt_restoring: False From 6de4424170129d21fb4c11b38a62cccfc8e9a0d2 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:51 -0700 Subject: [PATCH 26/69] Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred --- .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 
znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C From 18ec247aa0b181f1ded5642093027d1cce109b3e Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:17:01 -0700 Subject: [PATCH 27/69] Delete src/maxdiffusion/tests/ltx_transformer_step_test.py --- .../tests/ltx_transformer_step_test.py | 201 ------------------ 1 file changed, 201 deletions(-) delete mode 100644 
src/maxdiffusion/tests/ltx_transformer_step_test.py diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 9a816d6e5..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config 
= json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() From 2737877eacc8be8400abff8b3bb68a4dccce888f Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 8 Jul 2025 22:56:14 +0000 Subject: [PATCH 28/69] added header --- .../models/ltx_video/gradient_checkpoint.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/linear.py | 16 ++++++++++++++++ .../models/ltx_video/repeatable_layer.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/activations.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/adaln.py | 16 
++++++++++++++++ .../models/ltx_video/transformers/attention.py | 16 ++++++++++++++++ .../ltx_video/transformers/caption_projection.py | 16 ++++++++++++++++ .../ltx_video/transformers/transformer3d.py | 16 ++++++++++++++++ 8 files changed, 128 insertions(+) diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index ef8c530ba..ee7221652 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto from typing import Optional diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index 31b21cdd9..3503ab3b4 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Union, Iterable, Tuple, Optional, Callable import numpy as np diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index aaed41048..7e9cc80c4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 4a78b48ea..c97a6874e 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 4bc27e8bc..e9b287649 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Dict, Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index b9185825f..b5e9e030e 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from functools import partial import math from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index f2b1af101..d8240989c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from flax import linen as nn import jax.numpy as jnp diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 9466d1f84..2d82e328c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import List, Optional, Tuple import jax From 546ecab301c4e321d78bb2464a0df936c55beecb Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 00:20:43 +0000 Subject: [PATCH 29/69] ruff fixed --- src/maxdiffusion/generate_ltx_video.py | 8 +--- src/maxdiffusion/max_utils.py | 23 +---------- src/maxdiffusion/pyconfig.py | 39 +++++++++++-------- .../tests/ltx_transformer_step_test.py | 2 +- 4 files changed, 27 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 371d309e3..fa495ba1a 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -64,10 +64,6 @@ def run_inference( segment_ids=segment_ids, encoder_attention_segment_ids=encoder_attention_segment_ids, ) - prof = profiler.Profiler(config) - prof.activate(optional_postfix="transformer step") - prof.deactivate() - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) @@ -176,8 +172,8 @@ def run(config): in_shardings=(state_shardings,), out_shardings=None, ) - with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): - noise_pred = p_run_inference(states).block_until_ready() + + noise_pred = p_run_inference(states).block_until_ready() print(noise_pred) # (4, 256, 128) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index d4a80a347..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -417,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 2555b330c..9398c9156 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -191,7 +191,7 @@ def test_one_step_transformer(self): in_shardings=(state_shardings,), out_shardings=None, ) - + noise_pred = p_run_inference(states).block_until_ready() noise_pred = torch.from_numpy(np.array(noise_pred)) From 12a247fe9fa71af862a69644ea4275ec7c7d791d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 17:39:09 +0000 Subject: [PATCH 30/69] added header --- .../models/ltx_video/transformers/activations.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/adaln.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/attention.py | 16 ++++++++++++++++ .../ltx_video/transformers/caption_projection.py | 16 ++++++++++++++++ .../ltx_video/transformers/transformer3d.py | 16 ++++++++++++++++ 5 files changed, 80 insertions(+) diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 8e7ffb321..4ae1d9a00 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 4bc27e8bc..e9b287649 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Dict, Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 8a7541263..9faab1ded 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from functools import partial import math from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index f2b1af101..d8240989c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from flax import linen as nn import jax.numpy as jnp diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index cf599f26c..d6c7cf4c4 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import List, Optional, Tuple import jax From 1062c72f3f4a9429404154d51d3b9081e87fe762 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:11:17 +0000 Subject: [PATCH 31/69] license headers --- .github/workflows/UnitTests.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/__init__.py | 15 +++++++++++++++ .../models/ltx_video/gradient_checkpoint.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/linear.py | 16 ++++++++++++++++ .../models/ltx_video/repeatable_layer.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/__init__.py | 15 +++++++++++++++ 7 files changed, 95 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..c1fa771d1 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index fa495ba1a..2dec16fa6 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,19 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + from absl import app from typing import Sequence import jax diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index ef8c530ba..ee7221652 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto from typing import Optional diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index 31b21cdd9..3503ab3b4 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Union, Iterable, Tuple, Optional, Callable import numpy as np diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index aaed41048..7e9cc80c4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ From 535c75eea64044e3b87fce1b4ebcee52b80600a1 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:34:49 +0000 Subject: [PATCH 32/69] exclude test --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c1fa771d1..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: From f6115df3dbee5c7c301224062abf447dcc6e25ca Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 18:45:31 +0000 Subject: [PATCH 33/69] auto script --- .../utils/convert_torch_weights_to_jax.py | 80 ++++++++++++++++++- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index 4e306c7ea..b66ec091e 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -2,21 +2,87 @@ import json from typing import Any, Dict, Optional + + import jax import jax.numpy as jnp from flax.training import train_state import optax import orbax.checkpoint as ocp from safetensors.torch import load_file +import requests +import shutil +from urllib.parse import urljoin -from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT +# from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax from huggingface_hub import hf_hub_download import os +import importlib +def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, 
module_to_import): + """ + Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. + Args: + github_base_url (str): The base URL of the GitHub repo. + base_path (str): The base path where the new folder will be created. + target_folder_name (str): The name of the folder to create. + files_to_move (list): A list of file names to download and move. + module_to_import (str): The full module path to import. + """ + target_path = os.path.join(base_path, target_folder_name) + + try: + # Create the target directory + os.makedirs(target_path, exist_ok=True) + print(f"Created directory: {target_path}") + + # Download and move files + for file_name in files_to_move: + file_url = urljoin(github_base_url, file_name) + destination_path = os.path.join(target_path, file_name) + + try: + response = requests.get(file_url, stream=True) + response.raise_for_status() + + with open(destination_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"Downloaded and moved: {file_name} -> {destination_path}") + + except requests.exceptions.RequestException as e: + print(f"Error downloading {file_name}: {e}") + return # Stop if there is an error. + except OSError as e: + print(f"Error writing file {file_name}: {e}") + return # Stop if there is an error. + print("Files downloaded and moved successfully.") + + # Verify that the folder exists + if not os.path.exists(target_path): + print(f"Error: Target folder {target_path} does not exist after files download.") + # Dynamically import the module + try: + imported_module = importlib.import_module(module_to_import) + print(f"Module '{module_to_import}' imported successfully.") + # Access the class + transformer_class = getattr(imported_module, "Transformer3DModel") + print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") + return transformer_class + except ImportError as e: + print(f"Error importing module '{module_to_import}': {e}") + except AttributeError as e: + print(f"Error accessing class 'Transformer3DModel': {e}") + + except OSError as e: + print(f"Error during file system operation: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") class Checkpointer: """ Checkpointer - to load and store JAX checkpoints @@ -136,7 +202,15 @@ def main(args): "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " "training loss when resuming from the converted checkpoint." 
) - + print("Downloading files from GitHub...") + github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" + ltx_repo_path = "../" + target_folder = "transformers_pytorch" + files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] + module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" + + Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) + print("Loading safetensors, flush = True") weight_file = "ltxv-13b-0.9.7-dev.safetensors" @@ -162,7 +236,7 @@ def main(args): if key in transformer_config: del transformer_config[key] - transformer = Transformer3DModel_PT.from_config(transformer_config) + transformer = Transformer3DModel.from_config(transformer_config) print("Loading torch weights into transformer..", flush=True) transformer.load_state_dict(torch_state_dict) From 8bf24a3b9f93a97e6b118a4b483f24a3e2b0001f Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 18:52:18 +0000 Subject: [PATCH 34/69] headers --- src/maxdiffusion/generate_ltx_video.py | 3 - src/maxdiffusion/max_utils.py | 61 +------- src/maxdiffusion/models/ltx_video/__init__.py | 16 ++ .../models/ltx_video/transformers/__init__.py | 15 ++ .../ltx_video/transformers/attention.py | 2 +- .../ltx_video/transformers/transformer3d.py | 2 +- .../utils/conversion_script_instruction.md | 7 +- .../utils/convert_torch_weights_to_jax.py | 139 ++++++++++-------- .../utils/diffusers_config_mapping.py | 16 ++ .../ltx_video/utils/skip_layer_strategy.py | 16 ++ .../models/ltx_video/utils/torch_compat.py | 16 ++ 11 files changed, 162 insertions(+), 131 deletions(-) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 9ed2c92f7..31712edd2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -112,6 +112,3 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": app.run(main) - - - diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..d4a80a347 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. @@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index e69de29bb..285b6e81c 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index e69de29bb..9ff757fc3 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index b5e9e030e..3b88270f6 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -452,7 +452,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 2d82e328c..5beb1b0d2 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -112,7 +112,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( #noqa: C408 + module_init_kwargs=dict( # noqa: C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, diff --git a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md index 3e571f7b7..a6ca08835 100644 --- a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md +++ b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md @@ -1,13 +1,10 @@ ### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions: -1. Create new tansformers_pytorch folder under models/ltx_video. 
-2. Move files from LTX repo, specifically, attention.py, embeddings.py, symmetric_patchifier.py, and transformer3d.py into the newly created folder. See here: https://github.com/Lightricks/LTX-Video/tree/main/ltx_video/models/transformers -3. Rename transformer3d.py to transformer_pt.py to distinguish from the pytorch version. Change classname to Transformer3DModel_PT. Also change classname in line "transformer = Transformer3DModel.from_config(transformer_config)" -4. Weight Downloading and Conversion +1. Weight Downloading and Conversion - If first time running (no local safetensors): \ In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json. - If already have local pytorch checkpoint: \ Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location -5. Restoring Jax Weights into transformer: +2. Restoring Jax Weights into transformer: - Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path. - Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder. diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index b66ec091e..82ff03bab 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -1,9 +1,24 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import argparse import json from typing import Any, Dict, Optional - import jax import jax.numpy as jnp from flax.training import train_state @@ -11,78 +26,78 @@ import orbax.checkpoint as ocp from safetensors.torch import load_file import requests -import shutil from urllib.parse import urljoin -# from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax from huggingface_hub import hf_hub_download import os import importlib + + def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import): - """ - Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. + """ + Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. + + Args: + github_base_url (str): The base URL of the GitHub repo. + base_path (str): The base path where the new folder will be created. 
+ target_folder_name (str): The name of the folder to create. + files_to_move (list): A list of file names to download and move. + module_to_import (str): The full module path to import. + """ - Args: - github_base_url (str): The base URL of the GitHub repo. - base_path (str): The base path where the new folder will be created. - target_folder_name (str): The name of the folder to create. - files_to_move (list): A list of file names to download and move. - module_to_import (str): The full module path to import. - """ + target_path = os.path.join(base_path, target_folder_name) - target_path = os.path.join(base_path, target_folder_name) + try: + # Create the target directory + os.makedirs(target_path, exist_ok=True) + print(f"Created directory: {target_path}") + # Download and move files + for file_name in files_to_move: + file_url = urljoin(github_base_url, file_name) + destination_path = os.path.join(target_path, file_name) + + try: + response = requests.get(file_url, stream=True) + response.raise_for_status() + + with open(destination_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"Downloaded and moved: {file_name} -> {destination_path}") + + except requests.exceptions.RequestException as e: + print(f"Error downloading {file_name}: {e}") + except OSError as e: + print(f"Error writing file {file_name}: {e}") + print("Files downloaded and moved successfully.") + + # Verify that the folder exists + if not os.path.exists(target_path): + print(f"Error: Target folder {target_path} does not exist after files download.") + # Dynamically import the module try: - # Create the target directory - os.makedirs(target_path, exist_ok=True) - print(f"Created directory: {target_path}") - - # Download and move files - for file_name in files_to_move: - file_url = urljoin(github_base_url, file_name) - destination_path = os.path.join(target_path, file_name) - - try: - response = requests.get(file_url, stream=True) - response.raise_for_status() - - with open(destination_path, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - - print(f"Downloaded and moved: {file_name} -> {destination_path}") - - except requests.exceptions.RequestException as e: - print(f"Error downloading {file_name}: {e}") - return # Stop if there is an error. - except OSError as e: - print(f"Error writing file {file_name}: {e}") - return # Stop if there is an error. 
- print("Files downloaded and moved successfully.") - - # Verify that the folder exists - if not os.path.exists(target_path): - print(f"Error: Target folder {target_path} does not exist after files download.") - # Dynamically import the module - try: - imported_module = importlib.import_module(module_to_import) - print(f"Module '{module_to_import}' imported successfully.") - # Access the class - transformer_class = getattr(imported_module, "Transformer3DModel") - print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") - return transformer_class - except ImportError as e: - print(f"Error importing module '{module_to_import}': {e}") - except AttributeError as e: - print(f"Error accessing class 'Transformer3DModel': {e}") - - except OSError as e: - print(f"Error during file system operation: {e}") - except Exception as e: - print(f"An unexpected error occurred: {e}") + imported_module = importlib.import_module(module_to_import) + print(f"Module '{module_to_import}' imported successfully.") + # Access the class + transformer_class = getattr(imported_module, "Transformer3DModel") + print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") + return transformer_class + except ImportError as e: + print(f"Error importing module '{module_to_import}': {e}") + except AttributeError as e: + print(f"Error accessing class 'Transformer3DModel': {e}") + + except OSError as e: + print(f"Error during file system operation: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + class Checkpointer: """ Checkpointer - to load and store JAX checkpoints @@ -204,13 +219,13 @@ def main(args): ) print("Downloading files from GitHub...") github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" - ltx_repo_path = "../" + ltx_repo_path = "../" target_folder = "transformers_pytorch" files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) - + print("Loading safetensors, flush = True") weight_file = "ltxv-13b-0.9.7-dev.safetensors" diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py index 832bda051..81094d676 100644 --- a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main def make_hashable_key(dict_key): def convert_value(value): if isinstance(value, list): diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py index 476d38c75..74e74c1c6 100644 --- a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py index 23fa871d6..6cbfea70b 100644 --- a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py +++ b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import re from copy import copy from dataclasses import dataclass From 0f8483ee775bc929d102e81be6e616001b74db26 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 18:58:57 +0000 Subject: [PATCH 35/69] pulled --- src/maxdiffusion/configs/ltx_video.yml | 18 ++++----- src/maxdiffusion/max_utils.py | 23 +---------- .../ltx_video/transformers/attention.py | 30 +++++++------- src/maxdiffusion/pyconfig.py | 39 +++++++++++-------- 4 files changed, 44 insertions(+), 66 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 034aea65c..03edfd51f 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -27,33 +27,29 @@ output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 -ici_fsdp_transpose_parallelism: 1 -ici_sequence_parallelism: 1 -ici_tensor_transpose_parallelism: 1 -ici_expert_parallelism: 1 -ici_sequence_parallelism: 1 - diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index d4a80a347..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -417,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") 
else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 3b88270f6..9faab1ded 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -18,6 +18,7 @@ import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto + import jax import jax.nn as jnn import jax.numpy as jnp @@ -213,7 +214,8 @@ def __call__( # Adaptive Norm if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 num_ada_params = self.scale_shift_table.shape[0] ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( batch_size, timestep.shape[1], num_ada_params, -1 @@ -452,7 +454,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821 + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") @@ -636,27 +638,20 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. - qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - None, - None, - ) # qkvo_sharding_spec = jax.sharding.PartitionSpec( # ("data", "fsdp", "fsdp_transpose", "expert"), # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), # None, # None, # ) - # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # None, - # None, - # None, - # None, - # ) + qkvo_sharding_spec = jax.sharding.PartitionSpec( + "data", + "fsdp", + None, + "tensor", + ) # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, @@ -841,7 +836,8 @@ def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: inner_dim = dim * self.mult if inner_dim < 256: raise ValueError("inner_dim must be at least 256") - inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 else: inner_dim = self.inner_dim diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From 7af151ab5755b744717e1e0e37c8de7e60492004 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:20:30 +0000 Subject: [PATCH 36/69] change base branch --- src/maxdiffusion/configs/ltx_video.yml | 94 +- src/maxdiffusion/generate_ltx_video.py | 370 ++-- .../ltx_video/transformers/attention.py | 1680 +++++++++-------- .../ltx_video/transformers/transformer3d.py | 560 +++--- 4 files changed, 1397 insertions(+), 1307 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 140fda56c..38e9012e8 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,18 +1,3 @@ -# Copyright 2025 Google LLC - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# https://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - #hardware hardware: 'tpu' skip_jax_distributed_system: False @@ -25,30 +10,91 @@ activations_dtype: 'bfloat16' run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False +#Checkpoints +text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +frame_rate: 30 +max_sequence_length: 512 +sampler: "from_checkpoint" + + -#parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] + + +# Generation parameters +pipeline_type: multi-scale +prompt: ["A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.", "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. 
The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."] +height: 512 +width: 512 +num_frames: 88 #344 +flow_shift: 5.0 +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +prompt_enhancement_words_threshold: 120 +# guidance_scale: [1, 1, 6, 8, 6, 1, 1] #4.5 +# stg_scale: [0, 0, 4, 4, 4, 2, 1] #1.0 +# rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] #0.7 +# num_inference_steps: 30 +# skip_final_inference_steps: 3 +# skip_initial_inference_steps: 0 +# guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] +# skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] +stg_mode: "attention_values" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +# cfg_star_rescale: True + + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + skip_initial_inference_steps: 0 + cfg_star_rescale: True + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + skip_final_inference_steps: 0 + cfg_star_rescale: True + +#Parallelism +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded + +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 @@ -65,4 +111,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True -enable_single_replica_ckpt_restoring: False +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 2dec16fa6..90c82747c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,202 +1,196 @@ -""" - Copyright 2025 Google LLC +import numpy as np +from absl import app +from typing import Sequence +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline +from 
maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline +from maxdiffusion import pyconfig +import jax.numpy as jnp +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from huggingface_hub import hf_hub_download +import imageio +from datetime import datetime +from maxdiffusion.utils import export_to_video - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +import os +import json +import torch +from pathlib import Path - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: -from absl import app -from typing import Sequence -import jax -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import os -import functools -import jax.numpy as jnp -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import Mesh, PartitionSpec as P -import orbax.checkpoint as ocp - - -def validate_transformer_inputs( - prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids -): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) - print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return noise_pred + # Calculate total 
padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding -def run(config): - key = jax.random.PRNGKey(42) - - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - base_dir = os.path.dirname(__file__) - - ##load in model config - config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841 - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - - # create dummy inputs: - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join( + char.lower() for char in text if char.isalpha() or char.isspace() ) - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = 
jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - validate_transformer_inputs( - prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids - ) - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - - noise_pred = p_run_inference(states).block_until_ready() - print(noise_pred) # (4, 256, 128) + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + current_length = 0 + + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + +def create_latent_upsampler(latent_upsampler_model_path: str, device: str): + latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) + latent_upsampler.to(device) + latent_upsampler.eval() + return latent_upsampler + +def get_unique_filename( + base: str, + ext: str, + prompt: str, + seed: int, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + for i in range(index_range): + filename = dir / \ + f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError( + f"Could not find a unique filename after {index_range} attempts." + ) + + +def run(config): + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding( + config.height, config.width, height_padded, width_padded) + # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold + # prompt_word_count = len(config.prompt.split()) + # enhance_prompt = ( + # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold + # ) + + seed = 10 # change this, generator in pytorch, used in prepare_latents + generator = torch.Generator().manual_seed(seed) + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) + if config.pipeline_type == "multi-scale": #move this to pipeline file?? 
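+        # Hedged sketch of the intended multi-scale flow, inferred from the config
+        # defaults above (not a confirmed spec): the first pass denoises at roughly
+        # downscale_factor * (height, width) using the first_pass guidance / STG /
+        # rescaling schedules, the spatial latent upsampler then increases the
+        # latent resolution, and the second pass refines the result while skipping
+        # most steps (skip_initial_inference_steps: 17 of num_inference_steps: 30).
+        # For example, height = width = 512 with downscale_factor = 0.6666666 puts
+        # the first pass at about 341 x 341 before any rounding the pipeline applies.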
+ spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile( + spatial_upscaler_model_name_or_path + ): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir= "/mnt/disks/diffusionproj", + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = create_latent_upsampler( + spatial_upscaler_model_path, "cpu" #device set to cpu for now + ) + pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) + stg_mode = config.stg_mode + if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": + skip_layer_strategy = SkipLayerStrategy.AttentionValues + elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": + skip_layer_strategy = SkipLayerStrategy.AttentionSkip + elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": + skip_layer_strategy = SkipLayerStrategy.Residual + elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": + skip_layer_strategy = SkipLayerStrategy.TransformerBlock + else: + raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") + # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, + # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, + # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images + images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config) + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, :config.num_frames, + pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + print(output_filename) + # Write video + with 
imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 9faab1ded..fba1cdfc3 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -15,10 +15,10 @@ # This implementation is based on the Torch version available at: # https://github.com/Lightricks/LTX-Video/tree/main from functools import partial +import functools import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto - import jax import jax.nn as jnn import jax.numpy as jnp @@ -40,876 +40,898 @@ ) + class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - - def __call__(self, x): - return x + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." - - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. 
Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] - - # 0. 
Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. 
Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. Self-Attention - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) - - attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) - - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. 
Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states - -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
- ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) - - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. 
Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", dtype=self.dtype, weight_dtype=self.weight_dtype, - name="to_out.0", matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any(device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print("Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 - assert cross_attention_kwargs.get("scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key = self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") - - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") - - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = 
nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") - - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" - - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) - - if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - - hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) - - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") - ) - - hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states + ) + # 4. 
Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + + def __call__( + self, + index: int, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + skip_layer_strategy = SkipLayerStrategy.AttentionValues + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") -class AttentionOp(nn.Module): + batch_size = hidden_states.shape[0] - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
- # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # ("data", "fsdp", "fsdp_transpose", "expert"), - # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - # None, - # None, - # ) - qkvo_sharding_spec = jax.sharding.PartitionSpec( - "data", - "fsdp", - None, - "tensor", - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) - - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. 
Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index = index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) - ** SRAM cache size for TPU - V5P - 1MB SRAM per core + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + block_index = -1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - 
) +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) -class ExplicitAttention(nn.Module): + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters } + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) #here skip_layer_mask is (48,3), changed this currently! 
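+            # Spatiotemporal guidance (STG) masking, as used further below: the
+            # incoming skip_layer_mask appears to carry one entry per transformer
+            # block and per batch element (the note above suggests a
+            # (num_blocks, batch) shape); block_index selects this block's row and
+            # the reshape to (batch_size, 1, 1) lets it broadcast over tokens and
+            # channels. With SkipLayerStrategy.AttentionValues the output becomes
+            # mask * attention_output + (1 - mask) * value_projection, so samples
+            # whose mask entry is 0 bypass this block's attention entirely.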
+ + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." - attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. - """ - - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - NOTE: if weight is in mixed precision, the operand should be in the same precision. 
- Args: - hidden_states (jax.Array): Input data + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) - Returns: - jax.Array: Normed data - """ + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionValues + ): + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." 
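+
+        # LayerNorm below operates on the trailing (channel) axis of the
+        # (batch, tokens, channels) input directly, while GroupNorm expects a
+        # channels-first layout, hence the swapaxes round-trip in that branch.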
+ + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkvo_sharding_spec = jax.sharding.PartitionSpec( + None, + None, + None, + None, + ) + qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
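+
+        Illustrative example (an assumption about typical shapes, not a measured
+        recommendation): with bfloat16 and a 4096-token query/key sequence,
+        block_q = block_k = min(1024, 4096) = 1024, while a 256-token key
+        sequence yields block_k = min(1024, 256) = 256.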
+
+    TPU kernel ops run in grids: the bigger the grid, the more data is loaded into SRAM,
+    and we want to utilize the SRAM as well as we can.
+
+    Grids that are too big will cause cache misses and slow down the computation while the faster SRAM
+    fetches the remaining block data from the slower HBM.
+
+    A balance has to be struck to get the best performance. That balance should be computed from the
+    information supplied by q and k (which give the query and key/value sequence lengths)
+    together with the SRAM cache size.
+
+    ** SRAM cache size for TPU
+    V5P - 1MB SRAM per core
+
+    Args:
+      q (jax.Array): Query tensor to be used
+      k (jax.Array): Key tensor to be used
+
+    Returns:
+      BlockSizes: Grid block sizes
+    """
+    max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+    return BlockSizes(
+        block_q=min(max_block_size, q.shape[-2]),
+        block_k_major=min(max_block_size, k.shape[-2]),
+        block_k=min(max_block_size, k.shape[-2]),
+        block_b=min(1, q.shape[0]),
+        block_q_major_dkv=min(max_block_size, q.shape[-2]),
+        block_k_major_dkv=min(max_block_size, k.shape[-2]),
+        block_q_dkv=min(max_block_size, q.shape[-2]),
+        block_k_dkv=min(max_block_size, k.shape[-2]),
+        block_q_dq=min(max_block_size, q.shape[-2]),
+        block_k_dq=min(512, k.shape[-2]),
+        block_k_major_dq=min(max_block_size, k.shape[-2]),
+    )
-    if self.elementwise_affine:
-      # convert into half-precision if necessary
-      hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype)
-    else:
-      hidden_states = hidden_states.astype(input_dtype)
-    return hidden_states
+class ExplicitAttention(nn.Module):
+  def __call__(
+      self,
+      q: jax.Array,
+      k: jax.Array,
+      v: jax.Array,
+      q_segment_ids: jax.Array,
+      kv_segment_ids: jax.Array,
+      sharding_mesh: Optional[jax.sharding.Mesh] = None,
+      dtype: jnp.dtype = jnp.float32,
+  ):
+    assert sharding_mesh is None, "Explicit attention does not support sharding mesh."
+    attn_mask = None
+    if kv_segment_ids is not None:
+      q_segment_ids_expanded = q_segment_ids[:, None, :, None]
+      kv_segment_ids_expanded = kv_segment_ids[:, None, None, :]
+      attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded
+
+    scale_factor = 1 / jnp.sqrt(q.shape[-1])
+    attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype)
+
+    if attn_mask is not None:
+      if attn_mask.dtype == jnp.bool_:
+        attn_bias = jnp.where(attn_mask, attn_bias, float("-inf"))
+      else:
+        attn_bias += attn_mask
+
+    attn_weight = q @ k.swapaxes(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = jnn.softmax(attn_weight, axis=-1)
+
+    return attn_weight @ v
+
+
+class RMSNorm(nn.Module):
+  """
+  RMSNorm is a normalization layer that normalizes the input using the root mean square.
+  """
+
+  epsilon: float
+  dtype: jnp.dtype = jnp.float32
+  elementwise_affine: bool = True
+  weight_dtype: jnp.dtype = jnp.float32
+  kernel_axes: Tuple[Optional[str], ...] = ()
+  scale_init: Initializer = nn.initializers.ones
+
+  @nn.compact
+  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    """
+    Forward pass of the RMSNorm layer.
+
+    First we compute the variance (mean of the square of the input)
+    and then normalize the input using the root mean square.
+
+    NOTE: if the weight is in mixed precision, the operand should be in the same precision.
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError(f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF 
as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out + return out \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index d6c7cf4c4..5213beb7d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -29,277 +29,305 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
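# --- Editor's illustrative sketch (not part of the patch): the pairwise rotation used by
# --- apply_rotary_emb above. Adjacent feature pairs (t1, t2) become (-t2, t1); with sin == 0 the
# --- rotation is the identity and the output equals the input. Shapes are assumed for the example.
import jax.numpy as jnp

x = jnp.arange(8.0).reshape(1, 1, 1, 8)  # (batch, heads, tokens, head_dim)
cos = jnp.ones_like(x)                   # placeholder cosine frequencies
sin = jnp.zeros_like(x)                  # zero angle -> identity rotation

pairs = x.reshape(*x.shape[:-1], -1, 2)
t1, t2 = jnp.split(pairs, 2, axis=-1)
x_rot = jnp.concatenate([-t2, t1], axis=-1).reshape(*x.shape)

out = x * cos + x_rot * sin
assert bool(jnp.allclose(out, x))        # identity case recovered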
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( # noqa C408 - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def init_weights(self, in_channels, key, caption_channels, eval_only=True): - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "hidden_states": (batch_size, num_tokens, in_channels), - "indices_grid": (batch_size, 3, num_tokens), - "encoder_hidden_states": (batch_size, 128, caption_channels), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", 
"encoder_attention_mask"] else jnp.bool - ) - - if eval_only: - return jax.eval_shape( - self.init, - key, - **example_inputs, - )["params"] - else: - return self.init(key, **example_inputs)["params"] - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - - if self.num_layers > 0: - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - ) - # Output processing - - scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." 
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + def init_weights(self, in_channels, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else 
jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + jax.random.PRNGKey(42), ##need to change? + **example_inputs, + )["params"] + else: + return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones( + (self.num_layers, batch_size * num_conds), dtype=self.dtype + ) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + - Args: - x (jax.Array): log value - base (jax.Array): base of the log - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. 
- """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - dtype = jnp.float32 # We need full precision in the freqs_cis computation. - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
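# --- Editor's worked check (not part of the patch): log_base above is the change-of-base identity
# --- log_b(x) = ln(x) / ln(b). The frequency computation below uses it so that the linspace runs
# --- from log_theta(1) = 0 up to log_theta(theta) = 1.
import jax.numpy as jnp

def log_base(x, base):
    # change-of-base identity: log_b(x) = ln(x) / ln(b)
    return jnp.log(x) / jnp.log(base)

print(log_base(jnp.asarray(8.0), jnp.asarray(2.0)))          # ~3.0
print(log_base(jnp.asarray(10000.0), jnp.asarray(10000.0)))  # ~1.0, the upper linspace endpoint
print(log_base(jnp.asarray(1.0), jnp.asarray(10000.0)))      # 0.0, the lower linspace endpoint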
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) From 634591b86ab7ef1fda2c3a7ac803b37ce0cd80ee Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:46:25 +0000 Subject: [PATCH 37/69] save now --- src/maxdiffusion/configs/ltx_video.yml | 23 +- .../models/ltx_video/autoencoders/__init__.py | 0 .../ltx_video/autoencoders/causal_conv3d.py | 63 + .../autoencoders/causal_video_autoencoder.py | 1398 +++++++++++++++++ .../ltx_video/autoencoders/conv_nd_factory.py | 90 ++ .../ltx_video/autoencoders/dual_conv3d.py | 217 +++ .../autoencoders/latent_upsampler.py | 203 +++ .../ltx_video/autoencoders/pixel_norm.py | 12 + .../ltx_video/autoencoders/pixel_shuffle.py | 33 + .../models/ltx_video/autoencoders/vae.py | 380 +++++ .../ltx_video/autoencoders/vae_encode.py | 247 +++ .../autoencoders/video_autoencoder.py | 1045 ++++++++++++ .../ltx_video/transformers/attention.py | 9 +- .../transformers/symmetric_patchifier.py | 84 + .../ltx_video/transformers/transformer3d.py | 6 +- .../utils/diffusers_config_mapping.py | 174 ++ .../ltx_video/utils/prompt_enhance_utils.py | 226 +++ .../ltx_video/utils/skip_layer_strategy.py | 8 + .../models/ltx_video/utils/torch_utils.py | 25 + .../pipelines/ltx_video/__init__.py | 0 .../pipelines/ltx_video/ltx_video_pipeline.py | 1374 ++++++++++++++++ .../schedulers/scheduling_rectified_flow.py | 357 +++++ 22 files changed, 5953 insertions(+), 21 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py create mode 100644 
src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_utils.py create mode 100644 src/maxdiffusion/pipelines/ltx_video/__init__.py create mode 100644 src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py create mode 100644 src/maxdiffusion/schedulers/scheduling_rectified_flow.py diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 38e9012e8..cba635a1a 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -24,7 +24,7 @@ sampler: "from_checkpoint" # Generation parameters pipeline_type: multi-scale -prompt: ["A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.", "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."] +prompt: "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage." 
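# --- Editor's sketch (not part of the patch): reading the generation settings from the config above
# --- with plain PyYAML rather than maxdiffusion's own config machinery. PyYAML is assumed to be
# --- installed; the file path follows this patch.
import yaml

with open("src/maxdiffusion/configs/ltx_video.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["pipeline_type"])  # "multi-scale"
print(cfg["prompt"][:60])    # the single prompt string set above (now a string, not a list)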
height: 512 width: 512 num_frames: 88 #344 @@ -68,34 +68,29 @@ second_pass: skip_final_inference_steps: 0 cfg_star_rescale: True -#Parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 - -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 -ici_fsdp_transpose_parallelism: 1 -ici_sequence_parallelism: 1 -ici_tensor_transpose_parallelism: 1 -ici_expert_parallelism: 1 -ici_sequence_parallelism: 1 - diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py new file mode 100644 index 000000000..98249c2f5 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -0,0 +1,63 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py new file mode 100644 index 000000000..1255b6d34 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1398 @@ +import 
json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path + +import torch +import numpy as np +from einops import rearrange +from torch import nn +from diffusers.utils import logging +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open + + +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper +from maxdiffusion.models.ltx_video.transformers.attention import Attention +from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) + +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CausalVideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if ( + pretrained_model_name_or_path.is_dir() + and (pretrained_model_name_or_path / "autoencoder.pth").exists() + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( + std_of_means + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( + mean_of_means + ) + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "CausalVideoAutoencoder" + ), "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError( + f"latent_log_var={latent_log_var} requires use_quant_conv=False" + ) + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels + // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + 
patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_space", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_time_res", + ] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } + ckpt_state_dict = { + key: value + for key, value in state_dict.items() + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + + model_keys = set(name for name, _ in self.named_modules()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = 
block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. 
+ in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name.startswith("compress"): + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = 
DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0 + ) + self.last_scale_shift_table = nn.Parameter( + torch.randn(2, output_channel) / output_channel**0.5 + ) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError( + "attention_head_dim must be less than or equal to in_channels" + ) + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet( + hidden_states, 
causal=causal, timestep=timestep_embed + ) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, frames * height * width + ).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad( + hidden_states, (0, 0, 0, pad_len), "constant", 0 + ) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=( + None if not attention.use_tpu_flash_attention else mask + ), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, frames, height, width + ) + else: + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = ( + np.prod(stride) * in_channels // out_channels_reduction_factor + ) + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // 
self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, 
per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_demo_config( + latent_channels: int = 64, +): + encoder_blocks = [ + ("res_x", {"num_layers": 2}), + ("compress_space_res", {"multiplier": 2}), + ("compress_time_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ] + return { + "_class_name": 
"CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "spatial_padding_mode": "replicate", + } + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_demo_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) + + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + + timestep = torch.ones(input_videos.shape[0]) * 0.1 + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape, timestep=timestep + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode( + image_latent, target_shape=image_latent.shape, timestep=timestep + ).sample + + first_frame_latent = latent[:, :, :1, :, :] + + assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py new file mode 100644 index 000000000..1aa55ed9c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py @@ -0,0 +1,90 @@ +from typing import Tuple, Union + +import torch + +from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d +from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, 
+ stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", +): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py new file mode 100644 index 000000000..dcf889296 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -0,0 +1,217 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." 
+ ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and 
dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py new file mode 100644 index 000000000..8cb7d7d68 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py @@ -0,0 +1,203 @@ +from typing import Optional, Union +from pathlib import Path +import os +import json + +import torch +import torch.nn as nn +from einops import rearrange +from diffusers import ConfigMixin, ModelMixin +from safetensors.torch import safe_open + +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
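+ The latent is run through a stack of ResBlocks, upsampled channel-to-space (a convolution expands the channel count and PixelShuffleND rearranges the extra channels into the spatial and, optionally, temporal dimensions), then refined by a second ResBlock stack and projected back to the input channel count.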
+ + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + 
"spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler + + +if __name__ == "__main__": + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py new file mode 100644 index 000000000..9bc3ea60e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py new file mode 100644 index 000000000..4e79ae284 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from einops import rearrange + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py new file mode 100644 index 000000000..821a6b32b --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -0,0 +1,380 @@ +from typing import Optional, Union + +import torch +import inspect +import math +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd + + +class AutoencoderKLWrapper(ModelMixin, ConfigMixin): + """Variational 
Autoencoder (VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + normalize_latent_channels: bool = False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + self.normalize_latent_channels = normalize_latent_channels + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd( + quant_dims, 2 * latent_channels, 2 * latent_channels, 1 + ) + self.post_quant_conv = make_conv_nd( + quant_dims, latent_channels, latent_channels, 1 + ) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + + if normalize_latent_channels: + if dims == 2: + self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert ( + z_sample_size % 8 == 0 or z_sample_size == 1 + ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False + + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
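+ # Tiles are taken with a stride of `overlap_size`, so neighbouring tiles overlap in pixel space.
+ # After encoding, each tile is blended with the tile above (blend_v) and to its left (blend_h)
+ # over `blend_extent` latent pixels, cropped to `row_limit`, and the rows are concatenated back
+ # into a single latent tensor.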
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( + 1 - z / blend_extent + ) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
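+ # Mirrors the tiled encode path above: latent tiles of size `tile_latent_min_size` are decoded to
+ # `tile_sample_min_size` pixel tiles, blended with their top/left neighbours over `blend_extent`
+ # pixels, cropped to `row_limit`, and concatenated back together.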
+ rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = ( + sizes + [z.shape[2] - sum(sizes)] + if z.shape[2] - sum(sizes) > 0 + else sizes + ) + tiles = z.split(sizes, dim=2) + moments_tiles = [ + ( + self._hw_tiled_encode(z_tile, return_dict) + if self.use_hw_tiling + else self._encode(z_tile) + ) + for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = ( + self._hw_tiled_encode(z, return_dict) + if self.use_hw_tiling + else self._encode(z) + ) + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + _, c, _, _, _ = z.shape + z = torch.cat( + [ + self.latent_norm_out(z[:, : c // 2, :, :, :]), + z[:, c // 2 :, :, :, :], + ], + dim=1, + ) + elif isinstance(self.latent_norm_out, nn.BatchNorm2d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) + running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) + eps = self.latent_norm_out.eps + + z = z * torch.sqrt(running_var + eps) + running_mean + elif isinstance(self.latent_norm_out, nn.BatchNorm3d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + moments = self._normalize_latent_channels(moments) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self._unnormalize_latent_channels(z) + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and 
z.shape[2] > self.z_sample_size > 1: + reduction_factor = int( + self.encoder.patch_size_t + * 2 + ** ( + len(self.encoder.down_blocks) + - 1 + - math.sqrt(self.encoder.patch_size) + ) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode(z, target_shape) + if self.use_hw_tiling + else self._decode(z, target_shape=target_shape, timestep=timestep) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py new file mode 100644 index 000000000..5a0aeeccf --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py @@ -0,0 +1,247 @@ +from typing import Tuple +import torch +from diffusers import AutoencoderKL +from einops import rearrange +from torch import Tensor + + +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import ( + Downsample3D, + VideoAutoencoder, +) + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def vae_encode( + media_items: Tensor, + vae: AutoencoderKL, + split_size: int = 1, + vae_per_channel_normalize=False, +) -> Tensor: + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. 
Defaults to 1, which processes all items in a single batch. + + Returns: + Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted + to match the input shape, scaled by the model's configuration. + + Examples: + >>> import torch + >>> from diffusers import AutoencoderKL + >>> vae = AutoencoderKL.from_pretrained('your-model-name') + >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. + >>> latents = vae_encode(images, vae) + >>> print(latents.shape) # Output shape will depend on the model's latent configuration. + + Note: + In case of a video, the function encodes the media item frame-by frame. + """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + latents = vae.encode(media_items).latent_dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool = True, + split_size: int = 1, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder( + latent_batch, vae, is_video, vae_per_channel_normalize, timestep + ) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder( + latents, vae, is_video, vae_per_channel_normalize, timestep + ) + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images + + +def _run_decoder( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + 
return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len( + [ + block + for block in vae.encoder.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + spatial = vae.config.patch_size * 2**down_blocks + temporal = ( + vae.config.patch_size_t * 2**down_blocks + if isinstance(vae, VideoAutoencoder) + else 1 + ) + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords( + latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors( + latent_coords, scale_factors, causal_fix + ) + return pixel_coords + + +def latent_to_pixel_coords_from_factors( + latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False +) -> Tensor: + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py new file mode 100644 index 000000000..5b9ea640b --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -0,0 +1,1045 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from maxdiffusion.models.ltx_video.utils.torch_utils import Identity +from 
maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "VideoAutoencoder" + ), "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels + // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels + // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + 
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels + for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + * self.patch_size + ) + + def forward( + self, sample: torch.FloatTensor, return_features=False + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)( + sample, downsample_in_time=downsample_in_time + ) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block + and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, block_out_channels[0], out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward( + self, hidden_states: torch.FloatTensor, downsample_in_time + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample( + hidden_states, downsample_in_time=downsample_in_time + ) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D( + dims=dims, channels=out_channels, out_channels=out_channels + ) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, upsample_in_time=True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
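+        norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use. Can be either `group_norm` or `pixel_norm`.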
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
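+    :param dims: convolution type used by make_conv_nd: 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d.
+    :param out_channels: channels in the output; defaults to `channels` when not given.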
+ """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd( + dims, channels, out_channels, kernel_size=3, padding=1, bias=True + ) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + _, _, h, w = x.shape + + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) + + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange( + x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d + ) + # b c d h w + + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + + return self.conv(x) + + +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + 
"patch_size": 1, + } + + return config + + +def create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index fba1cdfc3..caff27add 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -670,12 +670,13 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): # ) # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), 
"sequence") qkvo_sharding_spec = jax.sharding.PartitionSpec( - None, - None, - None, + "data", + "fsdp", None, + "tensor", ) - qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py new file mode 100644 index 000000000..2eca32033 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 5213beb7d..7d068b0f6 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -153,7 +153,7 @@ def scale_shift_table_init(key): weight_dtype=self.weight_dtype, matmul_precision=self.matmul_precision, ) - def 
init_weights(self, in_channels, caption_channels, eval_only=True): + def init_weights(self, key, in_channels, caption_channels, eval_only=True): example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { @@ -172,11 +172,11 @@ def init_weights(self, in_channels, caption_channels, eval_only=True): if eval_only: return jax.eval_shape( self.init, - jax.random.PRNGKey(42), ##need to change? + key, ##need to change? **example_inputs, )["params"] else: - return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] + return self.init(key, **example_inputs)['params'] def create_skip_layer_mask( self, diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 000000000..53c0082d1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + 
"out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py new file mode 100644 index 000000000..901051728 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -0,0 +1,226 @@ +import logging +from typing import Union, List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. 
+Do not change the user input intent, just enhance it. +Keep within 150 words. +For best results, build your prompts using this structure: +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + +I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Keep within 150 words. +For best results, build your prompts using this structure: +Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Align to the image caption if it contradicts the user text input. +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + + +def tensor_to_pil(tensor): + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 + + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 + + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) + + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + + # Convert to PIL Image + return Image.fromarray(numpy_image) + + +def generate_cinematic_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompt: Union[str, List[str]], + conditioning_items: Optional[List] = None, + max_new_tokens: int = 256, +) -> List[str]: + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item( + first_frame_conditioning_item + ) + + assert len(first_frames) == len( + prompts + ), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts + + +def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: + frames_tensor = conditioning_item.media_item + return [ + tensor_to_pil(frames_tensor[i, :, 0, :, :]) + 
for i in range(frames_tensor.shape[0]) + ] + + +def _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + first_frames: List[Image.Image], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + image_captions = _generate_image_captions( + image_caption_model, image_caption_processor, first_frames + ) + + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_image_captions( + image_caption_model, + image_caption_processor, + images: List[Image.Image], + system_prompt: str = "", +) -> List[str]: + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor( + image_caption_prompts, images, return_tensors="pt" + ).to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int +) -> List[str]: + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate( + **model_inputs, max_new_tokens=max_new_tokens + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, outputs) + ] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + ) + + return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..30f9016e1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py new file mode 100644 index 000000000..991b07c36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -0,0 +1,25 @@ 
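+# Small torch helpers shared by the LTX-Video modules: append_dims pads trailing
+# singleton dimensions onto a tensor until it reaches a target rank, and Identity is
+# an argument-insensitive no-op module used where an optional layer is disabled.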
+import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/src/maxdiffusion/pipelines/ltx_video/__init__.py b/src/maxdiffusion/pipelines/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py new file mode 100644 index 000000000..d0a4ea6da --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -0,0 +1,1374 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
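+
+# JAX/Flax port of the LTX-Video generation pipeline. The Transformer3DModel runs as a
+# Flax module sharded over the device mesh, the causal video VAE is a PyTorch module,
+# the T5 text encoder is loaded via FlaxT5EncoderModel, and the optional prompt-enhancer
+# caption/LLM models come from HuggingFace transformers.
+#
+# Illustrative construction only (the HyperParameters config object is assumed to be
+# built elsewhere):
+#
+#     pipeline = LTXVideoPipeline.from_pretrained(config)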
+import argparse +import math +import os +import random +from jax import Array +from datetime import datetime +from pathlib import Path +from diffusers import AutoencoderKL +from typing import Optional, List, Union, Tuple +from einops import rearrange +import torch.nn.functional as F +from diffusers.utils.torch_utils import randn_tensor +# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +import yaml +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, + T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) + + +import imageio +import json +import numpy as np +import torch +from safetensors import safe_open +from PIL import Image +from transformers import ( + T5EncoderModel, + T5Tokenizer, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) +from huggingface_hub import hf_hub_download +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + latent_to_pixel_coords, + vae_decode, + vae_encode, + un_normalize_latents, + normalize_latents, +) +from diffusers.image_processor import VaeImageProcessor +from ltx_video.schedulers.rf import RectifiedFlowScheduler +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +import ltx_video.pipelines.crf_compressor as crf_compressor +from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from math import e +from types import NoneType +from typing import Any, Dict +import numpy as np +import inspect + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec as P +from typing import Optional, Union, List +import torch +from maxdiffusion.checkpointing import checkpointing_utils +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ...pyconfig import HyperParameters +# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState +from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState +from ...max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations +) +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import json +import functools +import orbax.checkpoint as ocp +import pickle + + +class PickleCheckpointHandler(ocp.CheckpointHandler): + def save(self, directory: str, item, args=None): + os.makedirs(directory, exist_ok=True) + with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: + pickle.dump(item, f) + + def restore(self, directory: str, args=None): + with open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f: + return pickle.load(f) + + def structure(self, directory: str): + return {} # not needed for simple pickle-based handling + + +def save_tensor_dict(tensor_dict, timestep): + base_dir = os.path.dirname(__file__) + local_path = os.path.join(base_dir, f"schedulerTest{timestep}") + + try: + torch.save(tensor_dict, local_path) + print(f"Dictionary of tensors saved to: {local_path}") + except Exception as e: + print(f"Error saving dictionary: {e}") + raise + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, 
segment_ids, encoder_attention_segment_ids): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", + encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def prepare_extra_step_kwargs(generator): + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + +class LTXVideoPipeline: + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: AutoencoderKL, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( + self.vae + ) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + # "Wan-AI/Wan2.1-T2V-14B-Diffusers", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( + # "Lightricks/LTX-Video", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # import pdb; pdb.set_trace() + # scheduler = FlaxRectifiedFlowMultistepScheduler( + # sampler="LinearQuadratic" + # ) + # scheduler_state = scheduler.create_state() + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + 
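+            # Read the transformer config; the relative ckpt_path and other bookkeeping
+            # keys it carries are removed further down before Transformer3DModel(**model_config)
+            # is instantiated and its sharded state is restored via Orbax.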
model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", + "causal_temporal_positioning", "in_channels", "ckpt_path"] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + transformer = Transformer3DModel( + # change this sharding back + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) + + weights_init_fn = functools.partial( + transformer.init_weights, + jax.random.PRNGKey(42), + in_channels, + model_config['caption_channels'], + eval_only=True + ) + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put( + transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + return vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + # text_encoder = T5EncoderModel.from_pretrained( + # ckpt_path, subfolder="text_encoder" + # ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + # tokenizer = T5Tokenizer.from_pretrained( + # ckpt_path, subfolder="tokenizer" + # ) + t5_tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, max_length=config.max_sequence_length, use_fast=True + ) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer( + config) + + # load from pytorch version + models_dir = "/mnt/disks/diffusionproj" #edit this! 
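+        # The LTX-Video safetensors checkpoint is pulled from the Lightricks/LTX-Video HF
+        # repo into models_dir when it is not already available locally; models_dir above
+        # is a local placeholder path and should be adjusted for the target environment.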
+ ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + vae = cls.load_vae(ltxv_model_path) + vae = vae.to(torch.bfloat16) + text_encoder = cls.load_text_encoder( + config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer( + config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( + config) + else: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + """ + Perform the denoising step for the required tokens, based on the current timestep and + conditioning mask: + Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) + and will start to be denoised when the current timestep is equal or lower than their + conditioning timestep. 
+ (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) + """ + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( #currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps # Use the original num_inference_steps here for the check + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[ + skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + ] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) + + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = ( + text_encoder_max_tokens # TPU supports only lengths multiple of 128 + ) + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + # text_enc_device = next(self.text_encoder.parameters()) + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, max_length - 1: -1] + ) + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids, 
attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + print(isinstance(prompt_embeds, Array)) + # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + # prompt_embeds = prompt_embeds.view( + # bs_embed * num_images_per_prompt, seq_len, -1 + # ) + prompt_embeds = jnp.reshape( + prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile( + prompt_attention_mask, (1, num_images_per_prompt)) + # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + # prompt_attention_mask = prompt_attention_mask.view( + # bs_embed * num_images_per_prompt, -1 + # ) + prompt_attention_mask = jnp.reshape( + prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance hasn't changed yet + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, + (1, num_images_per_prompt, 1) + ) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, + (batch_size * num_images_per_prompt, seq_len, -1) + ) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, + (1, num_images_per_prompt) + ) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, + (bs_embed * num_images_per_prompt, -1) + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds,#(1, 256, 4096) + prompt_attention_mask, #1, 256 + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." 
+ ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items, + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert ( + latents.shape == latent_shape + ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." + + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor( + (b, f * h * w, c), generator=generator, device=device, dtype=dtype + ) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, CausalVideoAutoencoder) + + # if conditioning_items: + # batch_size, _, num_latent_frames = init_latents.shape[:3] + + # init_conditioning_mask = torch.zeros( + # init_latents[:, 0, :, :, :].shape, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents = [] + # extra_conditioning_pixel_coords = [] + # extra_conditioning_mask = [] + # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # # Process each conditioning item + # for conditioning_item in conditioning_items: + # conditioning_item = self._resize_conditioning_item( + # conditioning_item, height, width + # ) + # media_item = conditioning_item.media_item + # media_frame_number = conditioning_item.media_frame_number + # strength = conditioning_item.conditioning_strength + # assert media_item.ndim == 5 # (b, c, f, h, w) + # b, c, n_frames, h, w = media_item.shape + # assert ( + # height == h and width == w + # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + # assert n_frames % 8 == 1 + # assert ( + # media_frame_number >= 0 + # and media_frame_number + n_frames <= num_frames + # ) + + # # Encode the provided conditioning media item + # media_item_latents = vae_encode( + # media_item.to(dtype=self.vae.dtype, device=self.vae.device), + # self.vae, + # vae_per_channel_normalize=vae_per_channel_normalize, + # ).to(dtype=init_latents.dtype) + + # # Handle the different conditioning cases + # if media_frame_number == 0: + # # Get the target spatial position of the latent conditioning item + # media_item_latents, l_x, l_y = self._get_latent_spatial_position( + # media_item_latents, + # conditioning_item, + # height, + # width, + # strip_latent_border=True, + # ) + # b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # # First frame or sequence - just 
update the initial noise latents and the mask + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + # torch.lerp( + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + # media_item_latents, + # strength, + # ) + # ) + # init_conditioning_mask[ + # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + # ] = strength + # else: + # # Non-first frame or sequence + # if n_frames > 1: + # # Handle non-first sequence. + # # Encoded latents are either fully consumed, or the prefix is handled separately below. + # ( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # ) = self._handle_non_first_conditioning_sequence( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # media_frame_number, + # strength, + # ) + + # # Single frame or sequence-prefix latents + # if media_item_latents is not None: + # noise = randn_tensor( + # media_item_latents.shape, + # generator=generator, + # device=media_item_latents.device, + # dtype=media_item_latents.dtype, + # ) + + # media_item_latents = torch.lerp( + # noise, media_item_latents, strength + # ) + + # # Patchify the extra conditioning latents and calculate their pixel coordinates + # media_item_latents, latent_coords = self.patchifier.patchify( + # latents=media_item_latents + # ) + # pixel_coords = latent_to_pixel_coords( + # latent_coords, + # self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, + # ) + + # # Update the frame numbers to match the target frame number + # pixel_coords[:, 0] += media_frame_number + # extra_conditioning_num_latents += media_item_latents.shape[1] + + # conditioning_mask = torch.full( + # media_item_latents.shape[:2], + # strength, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents.append(media_item_latents) + # extra_conditioning_pixel_coords.append(pixel_coords) + # extra_conditioning_mask.append(conditioning_mask) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify( + latents=init_latents + ) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True + + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + # init_conditioning_mask, _ = self.patchifier.patchify( + # latents=init_conditioning_mask.unsqueeze(1) + # ) + # init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + # if extra_conditioning_latents: + # # Stack the extra conditioning latents, pixel coordinates and mask + # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + # init_pixel_coords = torch.cat( + # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + # ) + # init_conditioning_mask = torch.cat( + # [*extra_conditioning_mask, init_conditioning_mask], dim=1 + # ) + + # if self.transformer.use_tpu_flash_attention: + # # When flash attention is used, keep the original number of tokens by removing + # # tokens from the end. 
+ # init_latents = init_latents[:, :-extra_conditioning_num_latents] + # init_pixel_coords = init_pixel_coords[ + # :, :, :-extra_conditioning_num_latents + # ] + # init_conditioning_mask = init_conditioning_mask[ + # :, :-extra_conditioning_num_latents + # ] + + # return ( + # init_latents, + # init_pixel_coords, + # init_conditioning_mask, + # extra_conditioning_num_latents, + # ) + + # change the paramters of these, currently pass in dummy inputs + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + vae_per_channel_normalize = kwargs.get( + "vae_per_channel_normalize", True) + import pdb; pdb.set_trace() + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) + + + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [ + i for i, val in enumerate(guidance_timesteps) if val <= timestep + ] + guidance_mapping.append( + indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) + ) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * 
len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask( + batch_size, num_conds, num_conds - 1, skip_blocks + ) + for skip_blocks in skip_block_list + ] + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate( + [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension + ) + prompt_attention_mask_batch = jnp.concatenate( + [negative_prompt_attention_mask, prompt_attention_mask], axis=0 + ) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = ( + self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + ) + + extra_step_kwargs = prepare_extra_step_kwargs( + generator=jax.random.PRNGKey(0)) + + pixel_coords = 
torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + # if not isinstance(guidance_scale, List): + # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) + # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + # latents = jax.device_put(example_inputs["latents"], data_sharding) + # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + noise_cond = jnp.ones( # initialize first round with this! + (1, 1) + ) + + # # noise_cond = None + # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" + # tensor_dict = torch.load(saved_tensor_path) + + # for key, value in tensor_dict.items(): + # if value is not None: + # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) + # example_inputs = tensor_dict + # latents = jax.device_put(example_inputs["latent_model_input"]) + # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) + # fractional_coords = jax.device_put(example_inputs["indices_grid"]) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) + # segment_ids = None + # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + + # #only run this for the first time! 
+ # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) + # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken + # scheduler_state = self.scheduler_state + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = jax.jit( + # functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # ), + # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct + # out_shardings=None, + # ) + segment_ids = None + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # # guidance_scale=guidance_scale + # ) + p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array( + fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling = do_rescaling, + rescaling_scale = rescaling_scale, + batch_size=batch_size, + skip_layer_masks = skip_layer_masks, + skip_layer_strategy = skip_layer_strategy, + cfg_star_rescale = cfg_star_rescale + ) + + with self.mesh: + latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( + # add scheduler state back in + torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) + latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] + // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [ + decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor( + decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to( + latents.device + )[:, None, None, 
None, None] + latents = ( + latents * (1 - decode_noise_scale) + + noise * decode_noise_scale + ) + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs.get( + "vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + image = self.image_processor.postprocess( + image, output_type=output_type) # shape mismatch here + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + # save states here + + +def transformer_forward_pass( # need to jit this? wan didnt + latents, + state, + noise_cond, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask, + skip_layer_strategy, +): + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy + ) # need .param here? + return noise_pred, state + + +def run_inference( + transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale +): + # do_classifier_free_guidance = guidance_scale > 1.0 + # for step in range(num_inference_steps): + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = ( + jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + ) + # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) + # timestep = jnp.broadcast_to( + # t, (latent_model_input.shape[0],) + # ).reshape(-1, 1) + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + # Determine the correct dtype based on the device (similar to PyTorch) + is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. 
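+      # Promote the Python scalar timestep to a JAX array so it can be broadcast to the batch dimension below.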
+ if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension (compatible with ONNX/Core ML) + current_timestep = jnp.broadcast_to( + current_timestep, (latent_model_input.shape[0],1) + ) + + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( + skip_layer_masks[i] + if skip_layer_masks is not None + else None + ), skip_layer_strategy = skip_layer_strategy) + # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) + + # # latents = self.denoising + # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated + if do_spatio_temporal_guidance: + # JAX uses jnp.split for splitting along an axis. 
+ # Equivalent to .chunk() for equal-sized chunks + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum( #(1, 1) + positive_flat * negative_flat, axis=1, keepdims=True + ) + squared_norm = ( #(1, 1) + jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + ) + alpha = dot_product / squared_norm #might need to reshape this (1, 1) + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) + noise_pred = noise_pred_uncond + guidance_scale[i] * ( + noise_pred_text - noise_pred_uncond + ) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] # JAX slicing is similar + latents, scheduler_state = scheduler.step( + scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + +def adain_filter_latent( + latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 +): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. 
+ Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean( + reference_latents[i, c], dim=None + ) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + +class LTXMultiScalePipeline: + def _upsample_latents( + self, latest_upsampler: LatentUpsampler, latents: torch.Tensor + ): + assert latents.device == latest_upsampler.device + + latents = un_normalize_latents( + latents, self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler(latents) + upsampled_latents = normalize_latents( + upsampled_latents, self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents + + def __init__( + self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + ): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + self.latent_upsampler = latent_upsampler + + def __call__( + self, + height, + width, + num_frames, + is_video, + output_type, + generator, + config + ) -> Any: + + original_output_type = output_type + original_width = width + original_height = height + x_width = int(width * config.downscale_factor) + downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) + x_height = int(height * config.downscale_factor) + downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) + #use original height and width here to see + output_type = 'latent' + result = self.video_pipeline(height=original_height, width=original_width, num_frames=num_frames, + is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], + num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) + latents = result + upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents + ) + + + + latents = upsampled_latents + output_type = original_output_type + width = downscaled_width * 2 + height = downscaled_height * 2 + + result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, + is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], + num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, 
skip_block_list=config.second_pass["skip_block_list"]) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(original_height, original_width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py new file mode 100644 index 000000000..1624b81c4 --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -0,0 +1,357 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from pathlib import Path +import os +from safetensors import safe_open + +import flax +import jax +import jax.numpy as jnp +import json +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.utils import is_scipy_available +from maxdiffusion.schedulers.scheduling_utils_flax import ( + CommonSchedulerState, + FlaxSchedulerMixin, + FlaxSchedulerOutput, +) + +def linear_quadratic_schedule_jax(num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) + +def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + +def _prod_jax(iterable): + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + +def get_normal_shift_jax( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b +def append_dims_jax(x: jnp.ndarray, target_dims: 
int) -> jnp.ndarray: + """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" + return x[(...,) + (None,) * (target_dims - x.ndim)] + + + +def strech_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray: + if shifts.size == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + one_minus_z = 1.0 - shifts + # Using shifts[-1] for the last element + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched_shifts = 1.0 - (one_minus_z / scale_factor) + + return stretched_shifts + +def sd3_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + target_shift_terminal: Optional[float] = None, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + + shift = get_normal_shift_jax(int(m)) + time_shifts = time_shift_jax(shift, 1.0, timesteps) + + if target_shift_terminal is not None: + time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) + return time_shifts + + +def simple_diffusion_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + n: int = 32 * 32, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + # Ensure m and n are float32 for calculations + m_f = jnp.array(m, dtype=jnp.float32) + n_f = jnp.array(n, dtype=jnp.float32) + + snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability + shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability + shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) + + return shifted_timesteps + +@flax.struct.dataclass +class RectifiedFlowSchedulerState: + """ + Data class to hold the mutable state of the RectifiedFlowScheduler. + """ + + common: CommonSchedulerState + init_noise_sigma: float + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + + + + @classmethod + def create( #need to change this! 
+ cls, + common_state: CommonSchedulerState, + init_noise_sigma: float + ): + return cls( + common = common_state, + init_noise_sigma = init_noise_sigma, + num_inference_steps = None, + timesteps = None, + sigmas = None, + ) + + +@dataclass +class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): + state: RectifiedFlowSchedulerState + + +class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + + dtype: jnp.dtype + order = 1 + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + beta_schedule: str = "linear", + rescale_zero_terminal_snr: bool = False, + beta_start: float = 0.0001, + beta_end: float = 0.02, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + sampler: Optional[str] = "Uniform", + shift: Optional[float] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + init_noise_sigma = 1.0 + return RectifiedFlowSchedulerState.create(common_state = common, init_noise_sigma=init_noise_sigma) + + + def get_initial_timesteps_jax( + self, num_timesteps: int, shift: Optional[float] = None + ) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler." 
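+      # The "Constant" sampler applies a fixed time shift to an otherwise uniform timestep grid.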
+ return time_shift_jax( + shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + ).astype(self.dtype) + else: + # This should be caught by __init__ but for safety + raise ValueError(f"Sampler {self.config.sampler} is not supported.") + + def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray: + if self.config.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift_jax( + samples_shape, timesteps, self.config.target_shift_terminal + ) + elif self.config.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift_jax( + samples_shape, timesteps, self.config.base_resolution + ) + return timesteps + + def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + config = None + if pretrained_model_path.is_file(): + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + configs = json.loads(metadata['config']) + config = configs["scheduler"] + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = ( + pretrained_model_path / "scheduler" / "scheduler_config.json" + ) + + if not diffusers_noise_scheduler_config_path.is_file(): + raise FileNotFoundError( + f"Scheduler config not found at {diffusers_noise_scheduler_config_path}" + ) + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + config = scheduler_config + return FlaxRectifiedFlowMultistepScheduler.from_config(config) + + def set_timesteps( + self, + state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[Tuple[int, ...]] = None, + timesteps: Optional[jnp.ndarray] = None, + device: Optional[str] = None, + ) -> RectifiedFlowSchedulerState: + if timesteps is not None and num_inference_steps is not None: + raise ValueError( + "You cannot provide both `timesteps` and `num_inference_steps`." + ) + + # Determine the number of inference steps if not provided + if num_inference_steps is None and timesteps is None: + raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") + + if timesteps is None: + num_inference_steps = jnp.minimum( + self.config.num_train_timesteps, num_inference_steps + ) + timesteps = self.get_initial_timesteps_jax( + num_inference_steps, shift=self.config.shift + ).astype(self.dtype) + + # Apply shifting if samples_shape is provided and shifting is configured + if samples_shape is not None: + timesteps = self.shift_timesteps_jax(samples_shape, timesteps) + else: + timesteps = jnp.asarray(timesteps, dtype=self.dtype) + num_inference_steps = len(timesteps) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + sigmas=timesteps, # sigmas are the same as timesteps in RF + ) + + def scale_model_input( + self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + # Rectified Flow scheduler typically doesn't scale model input, returns as is. + return sample + + + def step( + self, + state: RectifiedFlowSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, # Can be global or per-token, but for RF it's typically global. 
+ sample: jnp.ndarray, + return_dict: bool = True, + stochastic_sampling: bool = False, + generator: Optional[jax.random.PRNGKey] = None, + ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + t_eps = 1e-6 # Small epsilon for numerical issues + + timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) + + + if timestep.ndim == 0: + idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side='right') + current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] + lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), + timesteps_padded[current_t_idx + 1], + 0.0) + dt = timestep - lower_timestep + else: + current_t_indices = jnp.searchsorted(state.timesteps, timestep, side='right') # timesteps is decreasing + current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search + lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) + lower_timestep = timesteps_padded[lower_timestep_indices] + dt = timestep - lower_timestep + dt = append_dims_jax(dt, sample.ndim) + + + # Compute previous sample + if stochastic_sampling: + if generator is None: + raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") + broadcastable_timestep = append_dims_jax(timestep, sample.ndim) + + x0 = sample - broadcastable_timestep * model_output + next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep + + noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) + prev_sample = self.add_noise(state.common, x0, noise, next_timestep) + else: + prev_sample = sample - dt * model_output + + + if not return_dict: + return (prev_sample, state) + + return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) From a272d08f06ffb602c11572fbe4d6ab468aeb3e28 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:57:03 +0000 Subject: [PATCH 38/69] load transformer error --- src/maxdiffusion/models/ltx_video/transformers/transformer3d.py | 1 + src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 7d068b0f6..12f035031 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -154,6 +154,7 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) def init_weights(self, key, in_channels, caption_channels, eval_only=True): + import pdb; pdb.set_trace() example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index d0a4ea6da..422005b19 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -240,7 +240,7 @@ def load_transformer(cls, config): tx=None, config=config, mesh=mesh, - weights_init_fn=weights_init_fn, + weights_init_fn=None, checkpoint_manager=checkpoint_manager, 
checkpoint_item=" ", model_params=None, From 4bcffd11b5b5a2dffd2526080ccf092e5ad7eb4a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:01:04 +0000 Subject: [PATCH 39/69] later --- .../models/ltx_video/transformers/transformer3d.py | 10 ++++------ .../models/ltx_video/xora_v1.2-13B-balanced-128.json | 2 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 9 +++------ 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 12f035031..e3110a82b 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -153,8 +153,7 @@ def scale_shift_table_init(key): weight_dtype=self.weight_dtype, matmul_precision=self.matmul_precision, ) - def init_weights(self, key, in_channels, caption_channels, eval_only=True): - import pdb; pdb.set_trace() + def init_weights(self, in_channels, key, caption_channels, eval_only=True): example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { @@ -169,16 +168,15 @@ def init_weights(self, key, in_channels, caption_channels, eval_only=True): example_inputs[name] = jnp.ones( shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool ) - + if eval_only: return jax.eval_shape( self.init, - key, ##need to change? + key, **example_inputs, )["params"] else: - return self.init(key, **example_inputs)['params'] - + return self.init(key, **example_inputs)["params"] def create_skip_layer_mask( self, batch_size: int, diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index c5b3c0ef9..bce38fb20 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,5 +1,5 @@ { - "ckpt_path": "", + "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 422005b19..5443a539d 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -226,12 +226,9 @@ def load_transformer(cls, config): **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) weights_init_fn = functools.partial( - transformer.init_weights, - jax.random.PRNGKey(42), - in_channels, - model_config['caption_channels'], - eval_only=True + transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) @@ -240,7 +237,7 @@ def load_transformer(cls, config): tx=None, config=config, mesh=mesh, - weights_init_fn=None, + weights_init_fn=weights_init_fn, checkpoint_manager=checkpoint_manager, checkpoint_item=" ", model_params=None, From f5afa917d55b1095eeef8a0e40252ed8e3c031a7 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:09:49 +0000 Subject: [PATCH 40/69] changed repeatable layer --- .../models/ltx_video/repeatable_layer.py | 161 +++++++++--------- .../pipelines/ltx_video/ltx_video_pipeline.py | 2 +- 2 
files changed, 83 insertions(+), 80 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 7e9cc80c4..5e2ceb789 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,118 +1,121 @@ -# Copyright 2025 Lightricks Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This implementation is based on the Torch version available at: -# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional import jax from flax import linen as nn +import jax.numpy as jnp from flax.linen import partitioning class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - @nn.compact - def __call__(self, *args) -> Tuple[jax.Array, None]: - """ - jax carry-op format of block - assumes the input contains an input tensor to the block along with kwargs that might be send to the block - kwargs are assumed to have static role, while the input changes between cycles + @nn.compact + def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: + data_input, index_input = carry - Returns: - Tuple[jax.Array, None]: Output tensor from the block - """ - mod = self.module(*self.module_init_args, **self.module_init_kwargs) - output = mod(*args) - return output, None + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + # block_args are the static arguments passed to each individual block + output_data = mod(index_input, data_input, *block_args) # Pass block_args to the module + + next_index = index_input + 1 + new_carry = (output_data, next_index) + + + return new_carry, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - 
num_layers: int
-    """
+  num_layers: int
+  """
    The amount of blocks to build
    """
-    module_init_args: List[Any] = field(default_factory=list)
-    """
+  module_init_args: List[Any] = field(default_factory=list)
+  """
    args passed to RepeatableLayer.module callable, to support block construction
    """
-    module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
-    """
+  module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
+  """
    kwargs passed to RepeatableLayer.module callable, to support block construction
    """
-    pspec_name: Optional[str] = None
-    """
+  pspec_name: Optional[str] = None
+  """
    Partition spec metadata
    """
-    param_scan_axis: int = 0
-    """
+  param_scan_axis: int = 0
+  """
    The axis that the "layers" will be aggregated on
    eg: if a kernel is shaped (8, 16)
    N layers will be (N, 8, 16) if param_scan_axis=0
    and (8, N, 16) if param_scan_axis=1
    """

-    @nn.compact
-    def __call__(self, *args):
-
-        scan_kwargs = {}
-        if self.pspec_name is not None:
-            scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}
-
-        initializing = self.is_mutable_collection("params")
-        params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
-        scan_fn = nn.scan(
-            RepeatableCarryBlock,
-            variable_axes={
-                "params": params_spec,
-                "cache": 0,
-                "intermediates": 0,
-                "aqt": 0,
-                "_overwrite_with_gradient": 0,
-            },  # Separate params per timestep
-            split_rngs={"params": True},
-            in_axes=(nn.broadcast,) * (len(args) - 1),
-            length=self.num_layers,
-            **scan_kwargs,
-        )
-        wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)
-        x, _ = wrapped_function(*args)
-        return x
+  @nn.compact
+  def __call__(self, *args):  # args is now the full input to RepeatableLayer
+    if not args:
+      raise ValueError("RepeatableLayer expects at least one argument for initial data input.")
+
+    initial_data_input = args[0]  # The first element is your main data input
+    static_block_args = args[1:]  # Any subsequent elements are static args for each block
+
+    initial_index = jnp.array(0, dtype=jnp.int32)
+
+    scan_kwargs = {}
+    if self.pspec_name is not None:
+      scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}
+
+    initializing = self.is_mutable_collection("params")
+    params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
+
+    # in_axes for the scanned function (RepeatableCarryBlock.__call__):
+    # 1. The 'carry' tuple ((0, 0))
+    # 2. 
Then, nn.broadcast for each of the `static_block_args` + in_axes_for_scan = (nn.broadcast,) * (len(args)-1) + + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True}, + in_axes=in_axes_for_scan, + length=self.num_layers, + **scan_kwargs, + ) + + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + + # Call wrapped_function with the initial carry tuple and the static_block_args + (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) + + # Typically, you only want the final data output from the sequence of layers + return final_data \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 5443a539d..67ad161ce 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -228,7 +228,7 @@ def load_transformer(cls, config): weights_init_fn = functools.partial( transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) - + import pdb; pdb.set_trace() absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) From e8050346296831e5c8485081d38f8704efe66fbf Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:38:39 -0700 Subject: [PATCH 41/69] Update max_utils.py --- src/maxdiffusion/max_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -609,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() From bb61ecb24287a2a6beae81634fa5a7422de53952 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:51:33 +0000 Subject: [PATCH 42/69] functional --- src/maxdiffusion/checkpointing/checkpointing_utils.py | 7 +++++-- src/maxdiffusion/max_utils.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..5072b4639 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if 
checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..e13f31f94 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( From 7d4b2a9782ed4389341ce58e6d3abbcc02fa0083 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 22:02:55 +0000 Subject: [PATCH 43/69] moved upsampler --- .../pipelines/ltx_video/ltx_video_pipeline.py | 140 ++++++------------ 1 file changed, 45 insertions(+), 95 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 67ad161ce..35750ad5e 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -11,37 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import math import os -import random from jax import Array -from datetime import datetime -from pathlib import Path +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from diffusers import AutoencoderKL from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor -# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -import yaml -from transformers import (CLIPTokenizer, FlaxCLIPTextModel, - T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) - - -import imageio -import json -import numpy as np -import torch -from safetensors import safe_open -from PIL import Image from transformers import ( - T5EncoderModel, - T5Tokenizer, + FlaxT5EncoderModel, + AutoTokenizer, AutoModelForCausalLM, AutoProcessor, - AutoTokenizer, -) + AutoTokenizer,) +import json +import numpy as np +import torch from huggingface_hub import hf_hub_download from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, @@ -55,27 +42,19 @@ normalize_latents, ) from diffusers.image_processor import VaeImageProcessor -from ltx_video.schedulers.rf import RectifiedFlowScheduler from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler -import ltx_video.pipelines.crf_compressor as crf_compressor from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt from math import e from types import NoneType from typing import Any, Dict import numpy as np -import inspect import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P from typing import Optional, Union, List -import torch -from maxdiffusion.checkpointing import 
checkpointing_utils -from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...pyconfig import HyperParameters -# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState from ...max_utils import ( create_device_mesh, @@ -83,50 +62,9 @@ get_memory_allocations ) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import os import json import functools import orbax.checkpoint as ocp -import pickle - - -class PickleCheckpointHandler(ocp.CheckpointHandler): - def save(self, directory: str, item, args=None): - os.makedirs(directory, exist_ok=True) - with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: - pickle.dump(item, f) - - def restore(self, directory: str, args=None): - with open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f: - return pickle.load(f) - - def structure(self, directory: str): - return {} # not needed for simple pickle-based handling - - -def save_tensor_dict(tensor_dict, timestep): - base_dir = os.path.dirname(__file__) - local_path = os.path.join(base_dir, f"schedulerTest{timestep}") - - try: - torch.save(tensor_dict, local_path) - print(f"Dictionary of tensors saved to: {local_path}") - except Exception as e: - print(f"Error saving dictionary: {e}") - raise - - -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) - print("encoder_attention_segment_ids.shape: ", - encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - def prepare_extra_step_kwargs(generator): extra_step_kwargs = {} @@ -817,7 +755,6 @@ def __call__( skip_initial_inference_steps: int = 0, skip_final_inference_steps: int = 0, cfg_star_rescale: bool = False, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, **kwargs, ): @@ -1076,7 +1013,6 @@ def __call__( rescaling_scale = rescaling_scale, batch_size=batch_size, skip_layer_masks = skip_layer_masks, - skip_layer_strategy = skip_layer_strategy, cfg_star_rescale = cfg_star_rescale ) @@ -1149,7 +1085,6 @@ def transformer_forward_pass( # need to jit this? wan didnt segment_ids, encoder_attention_segment_ids, skip_layer_mask, - skip_layer_strategy, ): noise_pred = transformer.apply( {"params": state.params}, @@ -1160,13 +1095,12 @@ def transformer_forward_pass( # need to jit this? wan didnt segment_ids=segment_ids, encoder_attention_segment_ids=encoder_attention_segment_ids, skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy - ) # need .param here? 
+ ) return noise_pred, state def run_inference( - transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale + transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale ): # do_classifier_free_guidance = guidance_scale > 1.0 # for step in range(num_inference_steps): @@ -1206,7 +1140,7 @@ def run_inference( skip_layer_masks[i] if skip_layer_masks is not None else None - ), skip_layer_strategy = skip_layer_strategy) + )) # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) # # latents = self.denoising @@ -1294,6 +1228,31 @@ def adain_filter_latent( return result class LTXMultiScalePipeline: + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile( + spatial_upscaler_model_name_or_path + ): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir= "/mnt/disks/diffusionproj", + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + + def _upsample_latents( self, latest_upsampler: LatentUpsampler, latents: torch.Tensor ): @@ -1309,37 +1268,29 @@ def _upsample_latents( return upsampled_latents def __init__( - self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + self, video_pipeline: LTXVideoPipeline ): self.video_pipeline = video_pipeline self.vae = video_pipeline.vae - self.latent_upsampler = latent_upsampler - + def __call__( self, height, width, num_frames, - is_video, output_type, generator, config ) -> Any: + latent_upsampler = self.load_latent_upsampler(config) original_output_type = output_type - original_width = width - original_height = height - x_width = int(width * config.downscale_factor) - downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) - x_height = int(height * config.downscale_factor) - downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) - #use original height and width here to see output_type = 'latent' - result = self.video_pipeline(height=original_height, 
width=original_width, num_frames=num_frames, + result = self.video_pipeline(height=height, width=width, num_frames=num_frames, is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], - num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) + num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_block_list=config.first_pass["skip_block_list"]) latents = result - upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = self._upsample_latents(latent_upsampler, latents) upsampled_latents = adain_filter_latent( latents=upsampled_latents, reference_latents=latents ) @@ -1348,12 +1299,11 @@ def __call__( latents = upsampled_latents output_type = original_output_type - width = downscaled_width * 2 - height = downscaled_height * 2 + - result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, + result = self.video_pipeline(height=height*2, width=width*2, num_frames=num_frames, is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], - num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.second_pass["skip_block_list"]) + num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_block_list=config.second_pass["skip_block_list"]) if original_output_type != "latent": num_frames = result.shape[2] @@ -1361,7 +1311,7 @@ def __call__( videos = F.interpolate( videos, - size=(original_height, original_width), + size=(height, width), mode="bilinear", align_corners=False, ) From 972e316601fc896abfea9b9f81d49ce77ee037e8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 23:19:45 +0000 Subject: [PATCH 44/69] initial cleaning --- src/maxdiffusion/configs/ltx_video.yml | 10 +- src/maxdiffusion/generate_ltx_video.py | 50 +- .../models/ltx_video/repeatable_layer.py | 15 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 485 +++++++----------- 4 files changed, 191 insertions(+), 369 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index cba635a1a..e54216c72 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -32,18 +32,10 @@ flow_shift: 5.0 downscale_factor: 0.6666666 spatial_upscaler_model_path: 
"ltxv-spatial-upscaler-0.9.7.safetensors" prompt_enhancement_words_threshold: 120 -# guidance_scale: [1, 1, 6, 8, 6, 1, 1] #4.5 -# stg_scale: [0, 0, 4, 4, 4, 2, 1] #1.0 -# rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] #0.7 -# num_inference_steps: 30 -# skip_final_inference_steps: 3 -# skip_initial_inference_steps: 0 -# guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] -# skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] stg_mode: "attention_values" decode_timestep: 0.05 decode_noise_scale: 0.025 -# cfg_star_rescale: True +models_dir: "/mnt/disks/diffusionproj" #where safetensor file is first_pass: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 90c82747c..9ad816564 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,16 +4,12 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -import jax.numpy as jnp -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from huggingface_hub import hf_hub_download import imageio from datetime import datetime -from maxdiffusion.utils import export_to_video import os -import json import torch from pathlib import Path @@ -96,52 +92,12 @@ def run(config): num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 padding = calculate_padding( config.height, config.width, height_padded, width_padded) - # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold - # prompt_word_count = len(config.prompt.split()) - # enhance_prompt = ( - # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold - # ) - seed = 10 # change this, generator in pytorch, used in prepare_latents + seed = 10 generator = torch.Generator().manual_seed(seed) pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) - if config.pipeline_type == "multi-scale": #move this to pipeline file?? 
- spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile( - spatial_upscaler_model_name_or_path - ): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir= "/mnt/disks/diffusionproj", - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = create_latent_upsampler( - spatial_upscaler_model_path, "cpu" #device set to cpu for now - ) - pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) - stg_mode = config.stg_mode - if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": - skip_layer_strategy = SkipLayerStrategy.AttentionValues - elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": - skip_layer_strategy = SkipLayerStrategy.AttentionSkip - elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": - skip_layer_strategy = SkipLayerStrategy.Residual - elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": - skip_layer_strategy = SkipLayerStrategy.TransformerBlock - else: - raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") - # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, - # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, - # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images - images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config) + pipeline = LTXMultiScalePipeline(pipeline) + images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config) (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom pad_right = -pad_right diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 5e2ceb789..06b367ce4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -25,8 +25,7 @@ def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tup mod = self.module(*self.module_init_args, **self.module_init_kwargs) - # block_args are the static arguments passed to each individual block - output_data = mod(index_input, data_input, *block_args) # Pass block_args to the module + output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers next_index = index_input + 1 new_carry = (output_data, next_index) @@ -76,14 +75,14 @@ class RepeatableLayer(nn.Module): """ @nn.compact - def __call__(self, *args): # args is now the full input to RepeatableLayer + def __call__(self, *args): if not args: raise ValueError("RepeatableLayer 
expects at least one argument for initial data input.") - initial_data_input = args[0] # The first element is your main data input - static_block_args = args[1:] # Any subsequent elements are static args for each block + initial_data_input = args[0] + static_block_args = args[1:] - initial_index = jnp.array(0, dtype=jnp.int32) + initial_index = jnp.array(0, dtype=jnp.int32) #index of current transformer block scan_kwargs = {} if self.pspec_name is not None: @@ -92,9 +91,6 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer initializing = self.is_mutable_collection("params") params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) - # in_axes for the scanned function (RepeatableCarryBlock.__call__): - # 1. The 'carry' tuple ((0, 0)) - # 2. Then, nn.broadcast for each of the `static_block_args` in_axes_for_scan = (nn.broadcast,) * (len(args)-1) scan_fn = nn.scan( @@ -117,5 +113,4 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer # Call wrapped_function with the initial carry tuple and the static_block_args (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) - # Typically, you only want the final data output from the sequence of layers return final_data \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 35750ad5e..5584fcc0f 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -117,21 +117,6 @@ def __init__( @classmethod def load_scheduler(cls, ckpt_path, config): - # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - # "Wan-AI/Wan2.1-T2V-14B-Diffusers", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( - # "Lightricks/LTX-Video", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # import pdb; pdb.set_trace() - # scheduler = FlaxRectifiedFlowMultistepScheduler( - # sampler="LinearQuadratic" - # ) - # scheduler_state = scheduler.create_state() if config.sampler == "from_checkpoint" or not config.sampler: scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) else: @@ -159,14 +144,14 @@ def load_transformer(cls, config): for name in ignored_keys: if name in model_config: del model_config[name] + transformer = Transformer3DModel( - # change this sharding back **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) weights_init_fn = functools.partial( transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) - import pdb; pdb.set_trace() + ##load in jax weights checkpoint absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) @@ -194,17 +179,11 @@ def load_vae(cls, ckpt_path): @classmethod def load_text_encoder(cls, ckpt_path): - # text_encoder = T5EncoderModel.from_pretrained( - # ckpt_path, subfolder="text_encoder" - # ) t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) return t5_encoder @classmethod def load_tokenizer(cls, config, ckpt_path): - # tokenizer = T5Tokenizer.from_pretrained( - # ckpt_path, subfolder="tokenizer" - # ) t5_tokenizer = AutoTokenizer.from_pretrained( 
ckpt_path, max_length=config.max_sequence_length, use_fast=True ) @@ -235,7 +214,7 @@ def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): config) # load from pytorch version - models_dir = "/mnt/disks/diffusionproj" #edit this! + models_dir = config.models_dir ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" if not os.path.isfile(ltxv_model_name_or_path): ltxv_model_path = hf_hub_download( @@ -292,6 +271,7 @@ def process(text: str): return text return [process(t) for t in text] + def denoising_step( scheduler, latents: Array, @@ -303,14 +283,6 @@ def denoising_step( t_eps: float = 1e-6, stochastic_sampling: bool = False, ) -> Array: - """ - Perform the denoising step for the required tokens, based on the current timestep and - conditioning mask: - Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) - and will start to be denoised when the current timestep is equal or lower than their - conditioning timestep. - (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) - """ # Denoise the latents using the scheduler denoised_latents = scheduler.step( noise_pred, @@ -343,7 +315,7 @@ def retrieve_timesteps( #currently doesn't support custom timesteps skip_initial_inference_steps < 0 or skip_final_inference_steps < 0 or skip_initial_inference_steps + skip_final_inference_steps - >= num_inference_steps # Use the original num_inference_steps here for the check + >= num_inference_steps ): raise ValueError( "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" @@ -378,13 +350,13 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] max_length = ( - text_encoder_max_tokens # TPU supports only lengths multiple of 128 + text_encoder_max_tokens ) if prompt_embeds is None: assert ( self.text_encoder is not None ), "You should provide either prompt_embeds or self.text_encoder should not be None," - # text_enc_device = next(self.text_encoder.parameters()) + prompt = self._text_preprocessing(prompt) text_inputs = self.tokenizer( prompt, @@ -419,25 +391,15 @@ def encode_prompt( else: dtype = None bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - print(isinstance(prompt_embeds, Array)) - # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - # prompt_embeds = prompt_embeds.view( - # bs_embed * num_images_per_prompt, seq_len, -1 - # ) prompt_embeds = jnp.reshape( prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) prompt_attention_mask = jnp.tile( prompt_attention_mask, (1, num_images_per_prompt)) - # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) - # prompt_attention_mask = prompt_attention_mask.view( - # bs_embed * num_images_per_prompt, -1 - # ) prompt_attention_mask = jnp.reshape( prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - # get unconditional embeddings for classifier free guidance hasn't changed yet + # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens = self._text_preprocessing(negative_prompt) uncond_tokens = uncond_tokens * batch_size @@ -460,7 +422,6 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds[0] if 
do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = jnp.tile(negative_prompt_embeds, @@ -481,13 +442,13 @@ def encode_prompt( negative_prompt_attention_mask = None return ( - prompt_embeds,#(1, 256, 4096) - prompt_attention_mask, #1, 256 + prompt_embeds, + prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, ) - def prepare_latents( + def prepare_latents( ## this is in pytorch self, latents: torch.Tensor | None, media_items: torch.Tensor | None, @@ -556,123 +517,123 @@ def prepare_conditioning( assert isinstance(self.vae, CausalVideoAutoencoder) - # if conditioning_items: - # batch_size, _, num_latent_frames = init_latents.shape[:3] - - # init_conditioning_mask = torch.zeros( - # init_latents[:, 0, :, :, :].shape, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents = [] - # extra_conditioning_pixel_coords = [] - # extra_conditioning_mask = [] - # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) - - # # Process each conditioning item - # for conditioning_item in conditioning_items: - # conditioning_item = self._resize_conditioning_item( - # conditioning_item, height, width - # ) - # media_item = conditioning_item.media_item - # media_frame_number = conditioning_item.media_frame_number - # strength = conditioning_item.conditioning_strength - # assert media_item.ndim == 5 # (b, c, f, h, w) - # b, c, n_frames, h, w = media_item.shape - # assert ( - # height == h and width == w - # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - # assert n_frames % 8 == 1 - # assert ( - # media_frame_number >= 0 - # and media_frame_number + n_frames <= num_frames - # ) - - # # Encode the provided conditioning media item - # media_item_latents = vae_encode( - # media_item.to(dtype=self.vae.dtype, device=self.vae.device), - # self.vae, - # vae_per_channel_normalize=vae_per_channel_normalize, - # ).to(dtype=init_latents.dtype) - - # # Handle the different conditioning cases - # if media_frame_number == 0: - # # Get the target spatial position of the latent conditioning item - # media_item_latents, l_x, l_y = self._get_latent_spatial_position( - # media_item_latents, - # conditioning_item, - # height, - # width, - # strip_latent_border=True, - # ) - # b, c_l, f_l, h_l, w_l = media_item_latents.shape - - # # First frame or sequence - just update the initial noise latents and the mask - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( - # torch.lerp( - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], - # media_item_latents, - # strength, - # ) - # ) - # init_conditioning_mask[ - # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l - # ] = strength - # else: - # # Non-first frame or sequence - # if n_frames > 1: - # # Handle non-first sequence. - # # Encoded latents are either fully consumed, or the prefix is handled separately below. 
- # ( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # ) = self._handle_non_first_conditioning_sequence( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # media_frame_number, - # strength, - # ) - - # # Single frame or sequence-prefix latents - # if media_item_latents is not None: - # noise = randn_tensor( - # media_item_latents.shape, - # generator=generator, - # device=media_item_latents.device, - # dtype=media_item_latents.dtype, - # ) - - # media_item_latents = torch.lerp( - # noise, media_item_latents, strength - # ) - - # # Patchify the extra conditioning latents and calculate their pixel coordinates - # media_item_latents, latent_coords = self.patchifier.patchify( - # latents=media_item_latents - # ) - # pixel_coords = latent_to_pixel_coords( - # latent_coords, - # self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, - # ) - - # # Update the frame numbers to match the target frame number - # pixel_coords[:, 0] += media_frame_number - # extra_conditioning_num_latents += media_item_latents.shape[1] - - # conditioning_mask = torch.full( - # media_item_latents.shape[:2], - # strength, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents.append(media_item_latents) - # extra_conditioning_pixel_coords.append(pixel_coords) - # extra_conditioning_mask.append(conditioning_mask) + if conditioning_items: + batch_size, _, num_latent_frames = init_latents.shape[:3] + + init_conditioning_mask = torch.zeros( + init_latents[:, 0, :, :, :].shape, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents = [] + extra_conditioning_pixel_coords = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # Process each conditioning item + for conditioning_item in conditioning_items: + conditioning_item = self._resize_conditioning_item( + conditioning_item, height, width + ) + media_item = conditioning_item.media_item + media_frame_number = conditioning_item.media_frame_number + strength = conditioning_item.conditioning_strength + assert media_item.ndim == 5 # (b, c, f, h, w) + b, c, n_frames, h, w = media_item.shape + assert ( + height == h and width == w + ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + assert n_frames % 8 == 1 + assert ( + media_frame_number >= 0 + and media_frame_number + n_frames <= num_frames + ) + + # Encode the provided conditioning media item + media_item_latents = vae_encode( + media_item.to(dtype=self.vae.dtype, device=self.vae.device), + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ).to(dtype=init_latents.dtype) + + # Handle the different conditioning cases + if media_frame_number == 0: + # Get the target spatial position of the latent conditioning item + media_item_latents, l_x, l_y = self._get_latent_spatial_position( + media_item_latents, + conditioning_item, + height, + width, + strip_latent_border=True, + ) + b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # First frame or sequence - just update the initial noise latents and the mask + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + torch.lerp( + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + media_item_latents, + strength, + ) + ) + init_conditioning_mask[ + :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + ] = strength + else: + # Non-first 
frame or sequence + if n_frames > 1: + # Handle non-first sequence. + # Encoded latents are either fully consumed, or the prefix is handled separately below. + ( + init_latents, + init_conditioning_mask, + media_item_latents, + ) = self._handle_non_first_conditioning_sequence( + init_latents, + init_conditioning_mask, + media_item_latents, + media_frame_number, + strength, + ) + + # Single frame or sequence-prefix latents + if media_item_latents is not None: + noise = randn_tensor( + media_item_latents.shape, + generator=generator, + device=media_item_latents.device, + dtype=media_item_latents.dtype, + ) + + media_item_latents = torch.lerp( + noise, media_item_latents, strength + ) + + # Patchify the extra conditioning latents and calculate their pixel coordinates + media_item_latents, latent_coords = self.patchifier.patchify( + latents=media_item_latents + ) + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.vae, + causal_fix=self.transformer.config.causal_temporal_positioning, + ) + + # Update the frame numbers to match the target frame number + pixel_coords[:, 0] += media_frame_number + extra_conditioning_num_latents += media_item_latents.shape[1] + + conditioning_mask = torch.full( + media_item_latents.shape[:2], + strength, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents.append(media_item_latents) + extra_conditioning_pixel_coords.append(pixel_coords) + extra_conditioning_mask.append(conditioning_mask) # Patchify the updated latents and calculate their pixel coordinates init_latents, init_latent_coords = self.patchifier.patchify( @@ -683,46 +644,45 @@ def prepare_conditioning( self.vae, # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now causal_fix=True - ) if not conditioning_items: return init_latents, init_pixel_coords, None, 0 - # init_conditioning_mask, _ = self.patchifier.patchify( - # latents=init_conditioning_mask.unsqueeze(1) - # ) - # init_conditioning_mask = init_conditioning_mask.squeeze(-1) - - # if extra_conditioning_latents: - # # Stack the extra conditioning latents, pixel coordinates and mask - # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) - # init_pixel_coords = torch.cat( - # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 - # ) - # init_conditioning_mask = torch.cat( - # [*extra_conditioning_mask, init_conditioning_mask], dim=1 - # ) - - # if self.transformer.use_tpu_flash_attention: - # # When flash attention is used, keep the original number of tokens by removing - # # tokens from the end. 
- # init_latents = init_latents[:, :-extra_conditioning_num_latents] - # init_pixel_coords = init_pixel_coords[ - # :, :, :-extra_conditioning_num_latents - # ] - # init_conditioning_mask = init_conditioning_mask[ - # :, :-extra_conditioning_num_latents - # ] - - # return ( - # init_latents, - # init_pixel_coords, - # init_conditioning_mask, - # extra_conditioning_num_latents, - # ) - - # change the paramters of these, currently pass in dummy inputs + init_conditioning_mask, _ = self.patchifier.patchify( + latents=init_conditioning_mask.unsqueeze(1) + ) + init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + if extra_conditioning_latents: + # Stack the extra conditioning latents, pixel coordinates and mask + init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + init_pixel_coords = torch.cat( + [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + ) + init_conditioning_mask = torch.cat( + [*extra_conditioning_mask, init_conditioning_mask], dim=1 + ) + + if self.transformer.use_tpu_flash_attention: + # When flash attention is used, keep the original number of tokens by removing + # tokens from the end. + init_latents = init_latents[:, :-extra_conditioning_num_latents] + init_pixel_coords = init_pixel_coords[ + :, :, :-extra_conditioning_num_latents + ] + init_conditioning_mask = init_conditioning_mask[ + :, :-extra_conditioning_num_latents + ] + + return ( + init_latents, + init_pixel_coords, + init_conditioning_mask, + extra_conditioning_num_latents, + ) + + def __call__( self, @@ -881,8 +841,7 @@ def __call__( prompt_attention_mask_batch = prompt_attention_mask if do_classifier_free_guidance: prompt_embeds_batch = jnp.concatenate( - [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension - ) + [negative_prompt_embeds, prompt_embeds], axis=0) prompt_attention_mask_batch = jnp.concatenate( [negative_prompt_attention_mask, prompt_attention_mask], axis=0 ) @@ -898,7 +857,7 @@ def __call__( latents = self.prepare_latents( latents=latents, media_items=None, # set to None - timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this + timestep=scheduler_state.timesteps[0], latent_shape=latent_shape, dtype=None, device=None, @@ -925,73 +884,10 @@ def __call__( fractional_coords = pixel_coords.to(torch.float32) fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - # if not isinstance(guidance_scale, List): - # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) - # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - # latents = jax.device_put(example_inputs["latents"], data_sharding) - # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) noise_cond = jnp.ones( # initialize first round with this! 
(1, 1) ) - - # # noise_cond = None - # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" - # tensor_dict = torch.load(saved_tensor_path) - - # for key, value in tensor_dict.items(): - # if value is not None: - # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) - # example_inputs = tensor_dict - # latents = jax.device_put(example_inputs["latent_model_input"]) - # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) - # fractional_coords = jax.device_put(example_inputs["indices_grid"]) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) - # segment_ids = None - # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) - - # #only run this for the first time! - # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) - # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken - # scheduler_state = self.scheduler_state - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = jax.jit( - # functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # ), - # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct - # out_shardings=None, - # ) - segment_ids = None - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # # guidance_scale=guidance_scale - # ) + segment_ids = None #how is this created? p_run_inference = functools.partial( run_inference, transformer=self.transformer, @@ -1018,7 +914,6 @@ def __call__( with self.mesh: latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( - # add scheduler state back in torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) latents = torch.from_numpy(np.array(latents)) latents = latents[:, num_cond_latents:] @@ -1061,7 +956,7 @@ def __call__( timestep=decode_timestep, ) image = self.image_processor.postprocess( - image, output_type=output_type) # shape mismatch here + image, output_type=output_type) else: image = latents @@ -1072,7 +967,7 @@ def __call__( return (image,) return image - # save states here + def transformer_forward_pass( # need to jit this? wan didnt @@ -1102,21 +997,13 @@ def transformer_forward_pass( # need to jit this? 
wan didnt def run_inference( transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale ): - # do_classifier_free_guidance = guidance_scale > 1.0 - # for step in range(num_inference_steps): for i, t in enumerate(scheduler_state.timesteps): current_timestep = t latent_model_input = ( jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents ) - # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) - # timestep = jnp.broadcast_to( - # t, (latent_model_input.shape[0],) - # ).reshape(-1, 1) if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - # Determine the correct dtype based on the device (similar to PyTorch) - is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. + is_mps = False if isinstance(current_timestep, float): dtype = jnp.float32 else: @@ -1129,49 +1016,43 @@ def run_inference( elif current_timestep.ndim == 0: current_timestep = jnp.expand_dims(current_timestep, axis=0) - # Broadcast to batch dimension (compatible with ONNX/Core ML) + # Broadcast to batch dimension current_timestep = jnp.broadcast_to( current_timestep, (latent_model_input.shape[0],1) ) - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( - skip_layer_masks[i] - if skip_layer_masks is not None - else None - )) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( + skip_layer_masks[i] + if skip_layer_masks is not None + else None + )) # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) - # # latents = self.denoising - # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated + if do_spatio_temporal_guidance: - # JAX uses jnp.split for splitting along an axis. 
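The guidance combination just below splits the stacked prediction into its unconditional and conditional parts, optionally projects the unconditional branch onto the conditional one (the cfg_star_rescale path), and then applies the usual classifier-free-guidance blend. A small self-contained sketch of that math, assuming two stacked conditions and 3-D latents as in this pipeline (illustrative only, not the pipeline's API):

import jax.numpy as jnp

def combine_cfg(noise_pred, guidance_scale, cfg_star_rescale=True):
  # noise_pred stacks [unconditional, conditional] along the batch axis.
  uncond, text = jnp.split(noise_pred, 2, axis=0)
  if cfg_star_rescale:
    b = uncond.shape[0]
    pos = text.reshape(b, -1)
    neg = uncond.reshape(b, -1)
    # Scale the unconditional prediction by its projection onto the conditional one.
    alpha = jnp.sum(pos * neg, axis=1, keepdims=True) / (jnp.sum(neg ** 2, axis=1, keepdims=True) + 1e-8)
    uncond = alpha.reshape(b, 1, 1) * uncond
  return uncond + guidance_scale * (text - uncond)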
- # Equivalent to .chunk() for equal-sized chunks chunks = jnp.split(noise_pred, num_conds, axis=0) noise_pred_text = chunks[-2] noise_pred_text_perturb = chunks[-1] if do_classifier_free_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) noise_pred_uncond = chunks[0] noise_pred_text = chunks[1] if cfg_star_rescale: positive_flat = noise_pred_text.reshape(batch_size, -1) negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum( #(1, 1) + dot_product = jnp.sum( positive_flat * negative_flat, axis=1, keepdims=True ) - squared_norm = ( #(1, 1) + squared_norm = ( jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 ) - alpha = dot_product / squared_norm #might need to reshape this (1, 1) + alpha = dot_product / squared_norm alpha = alpha.reshape(batch_size, 1, 1) - noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) + noise_pred_uncond = alpha * noise_pred_uncond noise_pred = noise_pred_uncond + guidance_scale[i] * ( noise_pred_text - noise_pred_uncond ) @@ -1191,7 +1072,7 @@ def run_inference( noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] # JAX slicing is similar + current_timestep = current_timestep[:1] latents, scheduler_state = scheduler.step( scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() @@ -1295,8 +1176,6 @@ def __call__( latents=upsampled_latents, reference_latents=latents ) - - latents = upsampled_latents output_type = original_output_type From c37547102158920a1c4197e381156c83c9b2afd8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 18:47:16 +0000 Subject: [PATCH 45/69] multiscale pipeline --- .../checkpointing/checkpointing_utils.py | 2 +- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 202 +- src/maxdiffusion/max_utils.py | 2 +- .../ltx_video/autoencoders/causal_conv3d.py | 63 - .../autoencoders/causal_video_autoencoder.py | 1398 ----------- .../ltx_video/autoencoders/conv_nd_factory.py | 90 - .../ltx_video/autoencoders/dual_conv3d.py | 217 -- .../autoencoders/latent_upsampler.py | 203 -- .../ltx_video/autoencoders/pixel_norm.py | 12 - .../ltx_video/autoencoders/pixel_shuffle.py | 33 - .../models/ltx_video/autoencoders/vae.py | 380 --- .../ltx_video/autoencoders/vae_encode.py | 247 -- .../autoencoders/video_autoencoder.py | 1045 --------- .../{ => models}/autoencoders/__init__.py | 0 .../models/ltx_video/repeatable_layer.py | 130 +- .../ltx_video/transformers/attention.py | 1689 +++++++------- .../transformers/symmetric_patchifier.py | 84 - .../ltx_video/transformers/transformer3d.py | 580 +++-- .../utils/diffusers_config_mapping.py | 174 -- .../ltx_video/utils/prompt_enhance_utils.py | 226 -- .../ltx_video/utils/skip_layer_strategy.py | 8 - .../models/ltx_video/utils/torch_utils.py | 25 - .../pipelines/ltx_video/ltx_video_pipeline.py | 2062 ++++++++--------- .../schedulers/scheduling_rectified_flow.py | 422 ++-- 25 files changed, 2426 insertions(+), 6870 deletions(-) delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py delete mode 
100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py rename src/maxdiffusion/models/ltx_video/{ => models}/autoencoders/__init__.py (100%) delete mode 100644 src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_utils.py diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 5072b4639..b83e85a87 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -217,7 +217,7 @@ def load_state_if_possible( return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) else: item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index e54216c72..0fdbe7f9f 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -24,7 +24,7 @@ sampler: "from_checkpoint" # Generation parameters pipeline_type: multi-scale -prompt: "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage." +prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." 
height: 512 width: 512 num_frames: 88 #344 diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 9ad816564..f7d7e6d03 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,11 +4,8 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler -from huggingface_hub import hf_hub_download import imageio from datetime import datetime - import os import torch from pathlib import Path @@ -18,52 +15,45 @@ def calculate_padding( source_height: int, source_width: int, target_height: int, target_width: int ) -> tuple[int, int, int, int]: - # Calculate total padding needed - pad_height = target_height - source_height - pad_width = target_width - source_width + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width - # Calculate padding for each side - pad_top = pad_height // 2 - pad_bottom = pad_height - pad_top # Handles odd padding - pad_left = pad_width // 2 - pad_right = pad_width - pad_left # Handles odd padding + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding - # Return padded tensor - # Padding format is (left, right, top, bottom) - padding = (pad_left, pad_right, pad_top, pad_bottom) - return padding + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: - # Remove non-letters and convert to lowercase - clean_text = "".join( - char.lower() for char in text if char.isalpha() or char.isspace() - ) + # Remove non-letters and convert to lowercase + clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace()) - # Split into words - words = clean_text.split() + # Split into words + words = clean_text.split() - # Build result string keeping track of length - result = [] - current_length = 0 + # Build result string keeping track of length + result = [] + current_length = 0 - for word in words: - # Add word length plus 1 for underscore (except for first word) - new_length = current_length + len(word) + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) - if new_length <= max_len: - result.append(word) - current_length += len(word) - else: - break + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break - return "-".join(result) + return "-".join(result) -def create_latent_upsampler(latent_upsampler_model_path: str, device: str): - latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) - latent_upsampler.to(device) - latent_upsampler.eval() - return latent_upsampler def get_unique_filename( base: str, @@ -75,78 +65,82 @@ def get_unique_filename( endswith=None, index_range=1000, ) -> Path: - base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" - for i in range(index_range): - filename = dir / \ - f"{base_filename}_{i}{endswith if endswith else ''}{ext}" - if not 
os.path.exists(filename): - return filename - raise FileExistsError( - f"Could not find a unique filename after {index_range} attempts." - ) + base_filename = ( + f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + ) + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.") def run(config): - height_padded = ((config.height - 1) // 32 + 1) * 32 - width_padded = ((config.width - 1) // 32 + 1) * 32 - num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 - padding = calculate_padding( - config.height, config.width, height_padded, width_padded) - - seed = 10 - generator = torch.Generator().manual_seed(seed) - pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) - pipeline = LTXMultiScalePipeline(pipeline) - images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config) - (pad_left, pad_right, pad_top, pad_bottom) = padding - pad_bottom = -pad_bottom - pad_right = -pad_right - if pad_bottom == 0: - pad_bottom = images.shape[3] - if pad_right == 0: - pad_right = images.shape[4] - images = images[:, :, :config.num_frames, - pad_top:pad_bottom, pad_left:pad_right] - output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") - output_dir.mkdir(parents=True, exist_ok=True) - for i in range(images.shape[0]): - # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C - video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() - # Unnormalizing images to [0, 255] range - video_np = (video_np * 255).astype(np.uint8) - fps = config.frame_rate - height, width = video_np.shape[1:3] - # In case a single image is generated - if video_np.shape[0] == 1: - output_filename = get_unique_filename( - f"image_output_{i}", - ".png", - prompt=config.prompt, - seed=seed, - resolution=(height, width, config.num_frames), - dir=output_dir, - ) - imageio.imwrite(output_filename, video_np[0]) - else: - output_filename = get_unique_filename( - f"video_output_{i}", - ".mp4", - prompt=config.prompt, - seed=seed, - resolution=(height, width, config.num_frames), - dir=output_dir, - ) - print(output_filename) - # Write video - with imageio.get_writer(output_filename, fps=fps) as video: - for frame in video_np: - video.append_data(frame) + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding(config.height, config.width, height_padded, width_padded) + + seed = 10 + generator = torch.Generator().manual_seed(seed) + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False) + pipeline = LTXMultiScalePipeline(pipeline) + images = pipeline( + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + output_type="pt", + generator=generator, + config=config, + ) + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + 
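The `run()` hunk above rounds the requested output size up so it divides evenly into the model's patch grid: height and width go to the next multiple of 32, and the frame count to the next value of the form 8·k + 1. A small worked example of those formulas, using the `height: 512`, `width: 512`, `num_frames: 88` values from `ltx_video.yml`; the helper name is illustrative.

```python
def pad_dimensions(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
  # Same rounding used by run(): spatial dims to multiples of 32, frames to 8*k + 1.
  height_padded = ((height - 1) // 32 + 1) * 32
  width_padded = ((width - 1) // 32 + 1) * 32
  num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
  return height_padded, width_padded, num_frames_padded


print(pad_dimensions(512, 512, 88))   # (512, 512, 89)
print(pad_dimensions(500, 704, 121))  # (512, 704, 121)
```

`calculate_padding` then splits the spatial difference evenly between the two sides, and the generated video is cropped back to the requested size after decoding.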
output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + print(output_filename) + # Write video + with imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index e13f31f94..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -612,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py deleted file mode 100644 index 98249c2f5..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Tuple, Union - -import torch -import torch.nn as nn - - -class CausalConv3d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size: int = 3, - stride: Union[int, Tuple[int]] = 1, - dilation: int = 1, - groups: int = 1, - spatial_padding_mode: str = "zeros", - **kwargs, - ): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - - kernel_size = (kernel_size, kernel_size, kernel_size) - self.time_kernel_size = kernel_size[0] - - dilation = (dilation, 1, 1) - - height_pad = kernel_size[1] // 2 - width_pad = kernel_size[2] // 2 - padding = (0, height_pad, width_pad) - - self.conv = nn.Conv3d( - in_channels, - out_channels, - kernel_size, - stride=stride, - dilation=dilation, - padding=padding, - padding_mode=spatial_padding_mode, - groups=groups, - ) - - def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x - - @property - def weight(self): - return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py deleted file 
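The deleted `causal_conv3d.py` above achieves temporal causality purely through padding: in causal mode the first frame is repeated `kernel_size - 1` times before an ordinary `Conv3d`, so every output frame depends only on the current and earlier input frames. A condensed PyTorch sketch of that trick; the class name is illustrative and a cubic kernel is assumed.

```python
import torch
import torch.nn as nn


class TinyCausalConv3d(nn.Module):
  """Condensed illustration of the causal-padding idea from the deleted CausalConv3d."""

  def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3):
    super().__init__()
    self.time_pad = kernel_size - 1
    # Height/width are padded symmetrically; the time axis is padded manually in forward().
    self.conv = nn.Conv3d(in_ch, out_ch, kernel_size, padding=(0, kernel_size // 2, kernel_size // 2))

  def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, C, F, H, W)
    first_frame_pad = x[:, :, :1, :, :].repeat(1, 1, self.time_pad, 1, 1)
    return self.conv(torch.cat([first_frame_pad, x], dim=2))


out = TinyCausalConv3d(3, 8)(torch.randn(1, 3, 9, 32, 32))
print(out.shape)  # torch.Size([1, 8, 9, 32, 32]) - frame count is preserved
```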
mode 100644 index 1255b6d34..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py +++ /dev/null @@ -1,1398 +0,0 @@ -import json -import os -from functools import partial -from types import SimpleNamespace -from typing import Any, Mapping, Optional, Tuple, Union, List -from pathlib import Path - -import torch -import numpy as np -from einops import rearrange -from torch import nn -from diffusers.utils import logging -import torch.nn.functional as F -from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings -from safetensors import safe_open - - -from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm -from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND -from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper -from maxdiffusion.models.ltx_video.transformers.attention import Attention -from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( - diffusers_and_ours_config_mapping, - make_hashable_key, - VAE_KEYS_RENAME_DICT, -) - -PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class CausalVideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_name_or_path = Path(pretrained_model_name_or_path) - if ( - pretrained_model_name_or_path.is_dir() - and (pretrained_model_name_or_path / "autoencoder.pth").exists() - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - std_of_means = data_dict["std-of-means"] - mean_of_means = data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( - std_of_means - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( - mean_of_means - ) - - elif pretrained_model_name_or_path.is_dir(): - config_path = pretrained_model_name_or_path / "vae" / "config.json" - with open(config_path, "r") as f: - config = make_hashable_key(json.load(f)) - - assert config in diffusers_and_ours_config_mapping, ( - "Provided diffusers checkpoint config for VAE is not suppported. " - "We only support diffusers configs found in Lightricks/LTX-Video." 
- ) - - config = diffusers_and_ours_config_mapping[config] - - state_dict_path = ( - pretrained_model_name_or_path - / "vae" - / "diffusion_pytorch_model.safetensors" - ) - - state_dict = {} - with safe_open(state_dict_path, framework="pt", device="cpu") as f: - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - for key in list(state_dict.keys()): - new_key = key - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - - state_dict[new_key] = state_dict.pop(key) - - elif pretrained_model_name_or_path.is_file() and str( - pretrained_model_name_or_path - ).endswith(".safetensors"): - state_dict = {} - with safe_open( - pretrained_model_name_or_path, framework="pt", device="cpu" - ) as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - configs = json.loads(metadata["config"]) - config = configs["vae"] - - video_vae = cls.from_config(config) - if "torch_dtype" in kwargs: - video_vae.to(kwargs["torch_dtype"]) - video_vae.load_state_dict(state_dict) - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "CausalVideoAutoencoder" - ), "config must have _class_name=CausalVideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - normalize_latent_channels = config.get("normalize_latent_channels", False) - - if use_quant_conv and latent_log_var in ["uniform", "constant"]: - raise ValueError( - f"latent_log_var={latent_log_var} requires use_quant_conv=False" - ) - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - base_channels=config.get("encoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - causal=config.get("causal_decoder", False), - timestep_conditioning=config.get("timestep_conditioning", False), - base_channels=config.get("decoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - dims = config["dims"] - return CausalVideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - normalize_latent_channels=normalize_latent_channels, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="CausalVideoAutoencoder", - dims=self.dims, - in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, - out_channels=self.decoder.conv_out.out_channels - // self.decoder.patch_size**2, - latent_channels=self.decoder.conv_in.in_channels, - encoder_blocks=self.encoder.blocks_desc, - decoder_blocks=self.decoder.blocks_desc, - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - 
patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - causal_decoder=self.decoder.causal, - timestep_conditioning=self.decoder.timestep_conditioning, - normalize_latent_channels=self.normalize_latent_channels, - ) - - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def spatial_downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.encoder.blocks_desc - if block[0] - in [ - "compress_space", - "compress_all", - "compress_all_res", - "compress_space_res", - ] - ] - ) - * self.encoder.patch_size - ) - - @property - def temporal_downscale_factor(self): - return 2 ** len( - [ - block - for block in self.encoder.blocks_desc - if block[0] - in [ - "compress_time", - "compress_all", - "compress_all_res", - "compress_time_res", - ] - ] - ) - - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - if any([key.startswith("vae.") for key in state_dict.keys()]): - state_dict = { - key.replace("vae.", ""): value - for key, value in state_dict.items() - if key.startswith("vae.") - } - ckpt_state_dict = { - key: value - for key, value in state_dict.items() - if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - - model_keys = set(name for name, _ in self.named_modules()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - converted_state_dict = {} - for key, value in ckpt_state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - key_prefix = ".".join(key.split(".")[:-1]) - if "norm" in key and key_prefix not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - data_dict = { - key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value - for key, value in state_dict.items() - if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - if len(data_dict) > 0: - self.register_buffer("std_of_means", data_dict["std-of-means"]) - self.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) - - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer - - def set_use_tpu_flash_attention(self): - for block in self.decoder.up_blocks: - if isinstance(block, UNetMidBlock3D) and block.attention_blocks: - for attention_block in block.attention_blocks: - attention_block.set_use_tpu_flash_attention() - - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. 
- blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - self.blocks_desc = blocks - - in_channels = in_channels * patch_size**2 - output_channel = base_channels - - self.conv_in = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.down_blocks = nn.ModuleList([]) - - for block_name, block_params in blocks: - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 1, 1), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(1, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_res": - output_channel = 
block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown block: {block_name}") - - self.down_blocks.append(block) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var == "constant": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, - output_channel, - conv_out_channels, - 3, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)(sample) - - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - elif self.latent_log_var == "constant": - sample = sample[:, :-1, ...] - approx_ln_0 = ( - -30 - ) # this is the minimal clamp value in DiagonalGaussianDistribution objects - sample = torch.cat( - [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], - dim=1, - ) - - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. 
- in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - causal (`bool`, *optional*, defaults to `True`): - Whether to use causal convolutions or not. - """ - - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - causal: bool = True, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.layers_per_block = layers_per_block - out_channels = out_channels * patch_size**2 - self.causal = causal - self.blocks_desc = blocks - - # Compute output channel to be product of all channel-multiplier blocks - output_channel = base_channels - for block_name, block_params in list(reversed(blocks)): - block_params = block_params if isinstance(block_params, dict) else {} - if block_name == "res_x_y": - output_channel = output_channel * block_params.get("multiplier", 2) - if block_name.startswith("compress"): - output_channel = output_channel * block_params.get("multiplier", 1) - - self.conv_in = make_conv_nd( - dims, - in_channels, - output_channel, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.up_blocks = nn.ModuleList([]) - - for block_name, block_params in list(reversed(blocks)): - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = output_channel // block_params.get("multiplier", 2) - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=False, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = 
DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - output_channel = output_channel // block_params.get("multiplier", 1) - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 2, 2), - residual=block_params.get("residual", False), - out_channels_reduction_factor=block_params.get("multiplier", 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown layer: {block_name}") - - self.up_blocks.append(block) - - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, - output_channel, - out_channels, - 3, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.gradient_checkpointing = False - - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter( - torch.tensor(1000.0, dtype=torch.float32) - ) - self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - output_channel * 2, 0 - ) - self.last_scale_shift_table = nn.Parameter( - torch.randn(2, output_channel) / output_channel**0.5 - ) - - def forward( - self, - sample: torch.FloatTensor, - target_shape, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - batch_size = sample.shape[0] - - sample = self.conv_in(sample, causal=self.causal) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - sample = sample.to(upscale_dtype) - - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - sample = self.conv_norm_out(sample) - - if self.timestep_conditioning: - embedded_timestep = self.last_time_embedder( - timestep=scaled_timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=sample.shape[0], - hidden_dtype=sample.dtype, - ) - embedded_timestep = embedded_timestep.view( - batch_size, embedded_timestep.shape[-1], 1, 1, 1 - ) - ada_values = self.last_scale_shift_table[ - None, ..., None, None, None - ] + embedded_timestep.reshape( - batch_size, - 2, - -1, - embedded_timestep.shape[-3], - embedded_timestep.shape[-2], - embedded_timestep.shape[-1], - ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift - - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) - - sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - - return sample - - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - inject_noise (`bool`, *optional*, defaults to `False`): - Whether to inject noise into the hidden states. - timestep_conditioning (`bool`, *optional*, defaults to `False`): - Whether to condition the hidden states on the timestep. - attention_head_dim (`int`, *optional*, defaults to -1): - The dimension of the attention head. If -1, no attention is used. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. - - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - attention_head_dim: int = -1, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - in_channels * 4, 0 - ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - inject_noise=inject_noise, - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - for _ in range(num_layers) - ] - ) - - self.attention_blocks = None - - if attention_head_dim > 0: - if attention_head_dim > in_channels: - raise ValueError( - "attention_head_dim must be less than or equal to in_channels" - ) - - self.attention_blocks = nn.ModuleList( - [ - Attention( - query_dim=in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - bias=True, - out_bias=True, - qk_norm="rms_norm", - residual_connection=True, - ) - for _ in range(num_layers) - ] - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - timestep_embed = None - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - batch_size = hidden_states.shape[0] - timestep_embed = self.time_embedder( - timestep=timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - timestep_embed = timestep_embed.view( - batch_size, timestep_embed.shape[-1], 1, 1, 1 - ) - - if self.attention_blocks: - for resnet, attention in zip(self.res_blocks, self.attention_blocks): - hidden_states = resnet( - hidden_states, 
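The deleted decoder and mid-block above condition their activations on the denoising timestep in the AdaLN style: a learned `(2, C)` scale/shift table is added to a timestep embedding, and the result modulates the features as `x * (1 + scale) + shift`. A stripped-down sketch of just that modulation step; the function name and shapes are illustrative.

```python
import torch


def timestep_modulate(
    sample: torch.Tensor,             # (B, C, F, H, W) feature map
    embedded_timestep: torch.Tensor,  # (B, 2 * C) timestep embedding
    scale_shift_table: torch.Tensor,  # (2, C) learned parameter
) -> torch.Tensor:
  b, c = sample.shape[:2]
  ada_values = scale_shift_table[None, :, :, None, None, None] + embedded_timestep.view(b, 2, c, 1, 1, 1)
  shift, scale = ada_values.unbind(dim=1)  # each (B, C, 1, 1, 1)
  return sample * (1 + scale) + shift


x = timestep_modulate(torch.randn(2, 64, 5, 8, 8), torch.randn(2, 128), torch.randn(2, 64))
print(x.shape)  # torch.Size([2, 64, 5, 8, 8])
```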
causal=causal, timestep=timestep_embed - ) - - # Reshape the hidden states to be (batch_size, frames * height * width, channel) - batch_size, channel, frames, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, frames * height * width - ).transpose(1, 2) - - if attention.use_tpu_flash_attention: - # Pad the second dimension to be divisible by block_k_major (block in flash attention) - seq_len = hidden_states.shape[1] - block_k_major = 512 - pad_len = (block_k_major - seq_len % block_k_major) % block_k_major - if pad_len > 0: - hidden_states = F.pad( - hidden_states, (0, 0, 0, pad_len), "constant", 0 - ) - - # Create a mask with ones for the original sequence length and zeros for the padded indexes - mask = torch.ones( - (hidden_states.shape[0], seq_len), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - if pad_len > 0: - mask = F.pad(mask, (0, pad_len), "constant", 0) - - hidden_states = attention( - hidden_states, - attention_mask=( - None if not attention.use_tpu_flash_attention else mask - ), - ) - - if attention.use_tpu_flash_attention: - # Remove the padding - if pad_len > 0: - hidden_states = hidden_states[:, :-pad_len, :] - - # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, frames, height, width - ) - else: - for resnet in self.res_blocks: - hidden_states = resnet( - hidden_states, causal=causal, timestep=timestep_embed - ) - - return hidden_states - - -class SpaceToDepthDownsample(nn.Module): - def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): - super().__init__() - self.stride = stride - self.group_size = in_channels * np.prod(stride) // out_channels - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels // np.prod(stride), - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - def forward(self, x, causal: bool = True): - if self.stride[0] == 2: - x = torch.cat( - [x[:, :, :1, :, :], x], dim=2 - ) # duplicate first frames for padding - - # skip connection - x_in = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) - x_in = x_in.mean(dim=2) - - # conv - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - - x = x + x_in - - return x - - -class DepthToSpaceUpsample(nn.Module): - def __init__( - self, - dims, - in_channels, - stride, - residual=False, - out_channels_reduction_factor=1, - spatial_padding_mode="zeros", - ): - super().__init__() - self.stride = stride - self.out_channels = ( - np.prod(stride) * in_channels // out_channels_reduction_factor - ) - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) - self.residual = residual - self.out_channels_reduction_factor = out_channels_reduction_factor - - def forward(self, x, causal: bool = True): - if self.residual: - # Reshape and duplicate the input to match the output shape - x_in = self.pixel_shuffle(x) - num_repeat = np.prod(self.stride) // 
self.out_channels_reduction_factor - x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: - x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = self.pixel_shuffle(x) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x - - -class LayerNorm(nn.Module): - def __init__(self, dim, eps, elementwise_affine=True) -> None: - super().__init__() - self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward(self, x): - x = rearrange(x, "b c d h w -> b d h w c") - x = self.norm(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x - - -class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.inject_noise = inject_noise - - if norm_layer == "group_norm": - self.norm1 = nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) - - self.non_linearity = nn.SiLU() - - self.conv1 = make_conv_nd( - dims, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - if inject_noise: - self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) - - if norm_layer == "group_norm": - self.norm2 = nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = make_conv_nd( - dims, - out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - if inject_noise: - self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) - - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) - - self.norm3 = ( - LayerNorm(in_channels, eps=eps, elementwise_affine=True) - if in_channels != out_channels - else nn.Identity() - ) - - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.scale_shift_table = nn.Parameter( - torch.randn(4, in_channels) / in_channels**0.5 - ) - - def _feed_spatial_noise( - self, hidden_states: torch.FloatTensor, 
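The deleted `SpaceToDepthDownsample` and `DepthToSpaceUpsample` blocks above change resolution by moving stride factors between the channel axis and the frame/height/width axes rather than using strided or transposed convolutions. A minimal shape-level sketch of both rearranges with `einops`; the toy tensor sizes are illustrative.

```python
import torch
from einops import rearrange

x = torch.randn(1, 8, 4, 16, 16)  # (B, C, F, H, W)
p1, p2, p3 = 2, 2, 2              # temporal and spatial stride factors

# Space-to-depth: fold each 2x2x2 neighborhood into the channel axis (8 -> 64 channels).
down = rearrange(x, "b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=p1, p2=p2, p3=p3)
print(down.shape)  # torch.Size([1, 64, 2, 8, 8])

# Depth-to-space: the inverse rearrange restores the original shape exactly.
up = rearrange(down, "b (c p1 p2 p3) f h w -> b c (f p1) (h p2) (w p3)", p1=p1, p2=p2, p3=p3)
print(up.shape)    # torch.Size([1, 8, 4, 16, 16])
assert torch.equal(up, x)
```

In the deleted blocks this rearrange is paired with a 3x3x3 convolution and a repeated or mean-pooled skip connection so that the resampling stays residual.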
per_channel_scale: torch.FloatTensor - ) -> torch.FloatTensor: - spatial_shape = hidden_states.shape[-2:] - device = hidden_states.device - dtype = hidden_states.dtype - - # similar to the "explicit noise inputs" method in style-gan - spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] - scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] - hidden_states = hidden_states + scaled_noise - - return hidden_states - - def forward( - self, - input_tensor: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - hidden_states = input_tensor - batch_size = hidden_states.shape[0] - - hidden_states = self.norm1(hidden_states) - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - ada_values = self.scale_shift_table[ - None, ..., None, None, None - ] + timestep.reshape( - batch_size, - 4, - -1, - timestep.shape[-3], - timestep.shape[-2], - timestep.shape[-1], - ) - shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) - - hidden_states = hidden_states * (1 + scale1) + shift1 - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.conv1(hidden_states, causal=causal) - - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 - ) - - hidden_states = self.norm2(hidden_states) - - if self.timestep_conditioning: - hidden_states = hidden_states * (1 + scale2) + shift2 - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - - hidden_states = self.conv2(hidden_states, causal=causal) - - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 - ) - - input_tensor = self.norm3(input_tensor) - - batch_size = input_tensor.shape[0] - - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -def patchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - return x - - -def unpatchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x - - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - - return x - - -def create_video_autoencoder_demo_config( - latent_channels: int = 64, -): - encoder_blocks = [ - ("res_x", {"num_layers": 2}), - ("compress_space_res", {"multiplier": 2}), - ("compress_time_res", {"multiplier": 2}), - ("compress_all_res", {"multiplier": 2}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 1}), - ] - decoder_blocks = [ - ("res_x", {"num_layers": 2, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 2, "inject_noise": False}), - ] - return { - "_class_name": 
"CausalVideoAutoencoder", - "dims": 3, - "encoder_blocks": encoder_blocks, - "decoder_blocks": decoder_blocks, - "latent_channels": latent_channels, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, - "timestep_conditioning": True, - "spatial_padding_mode": "replicate", - } - - -def test_vae_patchify_unpatchify(): - import torch - - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) - - -def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_demo_config() - - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = CausalVideoAutoencoder.from_config(config) - - print(video_autoencoder) - video_autoencoder.eval() - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 17, 64, 64) - - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") - - timestep = torch.ones(input_videos.shape[0]) * 0.1 - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape, timestep=timestep - ).sample - - print(f"reconstructed shape={reconstructed_videos.shape}") - - # Validate that single image gets treated the same way as first frame - input_image = input_videos[:, :, :1, :, :] - image_latent = video_autoencoder.encode(input_image).latent_dist.mode() - _ = video_autoencoder.decode( - image_latent, target_shape=image_latent.shape, timestep=timestep - ).sample - - first_frame_latent = latent[:, :, :1, :, :] - - assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) - # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() - - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - - # Perform backward pass - loss.backward() - - print(f"Demo completed with loss: {loss.item()}") - - -# Ensure to call the demo function to execute the forward and backward pass -if __name__ == "__main__": - demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py deleted file mode 100644 index 1aa55ed9c..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple, Union - -import torch - -from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d -from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d - - -def make_conv_nd( - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - kernel_size: 
int, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - causal=False, - spatial_padding_mode="zeros", - temporal_padding_mode="zeros", -): - if not (spatial_padding_mode == temporal_padding_mode or causal): - raise NotImplementedError("spatial and temporal padding modes must be equal") - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == 3: - if causal: - return CausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - spatial_padding_mode=spatial_padding_mode, - ) - return torch.nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == (2, 1): - return DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=bias, - padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - -def make_linear_nd( - dims: int, - in_channels: int, - out_channels: int, - bias=True, -): - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - elif dims == 3 or dims == (2, 1): - return torch.nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py deleted file mode 100644 index dcf889296..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py +++ /dev/null @@ -1,217 +0,0 @@ -import math -from typing import Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - - -class DualConv3d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - dilation: Union[int, Tuple[int, int, int]] = 1, - groups=1, - bias=True, - padding_mode="zeros", - ): - super(DualConv3d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.padding_mode = padding_mode - # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size, kernel_size) - if kernel_size == (1, 1, 1): - raise ValueError( - "kernel_size must be greater than 1. Use make_linear_nd instead." 
- ) - if isinstance(stride, int): - stride = (stride, stride, stride) - if isinstance(padding, int): - padding = (padding, padding, padding) - if isinstance(dilation, int): - dilation = (dilation, dilation, dilation) - - # Set parameters for convolutions - self.groups = groups - self.bias = bias - - # Define the size of the channels after the first convolution - intermediate_channels = ( - out_channels if in_channels < out_channels else in_channels - ) - - # Define parameters for the first convolution - self.weight1 = nn.Parameter( - torch.Tensor( - intermediate_channels, - in_channels // groups, - 1, - kernel_size[1], - kernel_size[2], - ) - ) - self.stride1 = (1, stride[1], stride[2]) - self.padding1 = (0, padding[1], padding[2]) - self.dilation1 = (1, dilation[1], dilation[2]) - if bias: - self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) - else: - self.register_parameter("bias1", None) - - # Define parameters for the second convolution - self.weight2 = nn.Parameter( - torch.Tensor( - out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 - ) - ) - self.stride2 = (stride[0], 1, 1) - self.padding2 = (padding[0], 0, 0) - self.dilation2 = (dilation[0], 1, 1) - if bias: - self.bias2 = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter("bias2", None) - - # Initialize weights and biases - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) - nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) - if self.bias: - fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) - bound1 = 1 / math.sqrt(fan_in1) - nn.init.uniform_(self.bias1, -bound1, bound1) - fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) - bound2 = 1 / math.sqrt(fan_in2) - nn.init.uniform_(self.bias2, -bound2, bound2) - - def forward(self, x, use_conv3d=False, skip_time_conv=False): - if use_conv3d: - return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) - else: - return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) - - def forward_with_3d(self, x, skip_time_conv): - # First convolution - x = F.conv3d( - x, - self.weight1, - self.bias1, - self.stride1, - self.padding1, - self.dilation1, - self.groups, - padding_mode=self.padding_mode, - ) - - if skip_time_conv: - return x - - # Second convolution - x = F.conv3d( - x, - self.weight2, - self.bias2, - self.stride2, - self.padding2, - self.dilation2, - self.groups, - padding_mode=self.padding_mode, - ) - - return x - - def forward_with_2d(self, x, skip_time_conv): - b, c, d, h, w = x.shape - - # First 2D convolution - x = rearrange(x, "b c d h w -> (b d) c h w") - # Squeeze the depth dimension out of weight1 since it's 1 - weight1 = self.weight1.squeeze(2) - # Select stride, padding, and dilation for the 2D convolution - stride1 = (self.stride1[1], self.stride1[2]) - padding1 = (self.padding1[1], self.padding1[2]) - dilation1 = (self.dilation1[1], self.dilation1[2]) - x = F.conv2d( - x, - weight1, - self.bias1, - stride1, - padding1, - dilation1, - self.groups, - padding_mode=self.padding_mode, - ) - - _, _, h, w = x.shape - - if skip_time_conv: - x = rearrange(x, "(b d) c h w -> b c d h w", b=b) - return x - - # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) - - # Reshape weight2 to match the expected dimensions for conv1d - weight2 = self.weight2.squeeze(-1).squeeze(-1) - # Use only the relevant dimension for stride, padding, and 
dilation for the 1D convolution - stride2 = self.stride2[0] - padding2 = self.padding2[0] - dilation2 = self.dilation2[0] - x = F.conv1d( - x, - weight2, - self.bias2, - stride2, - padding2, - dilation2, - self.groups, - padding_mode=self.padding_mode, - ) - x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) - - return x - - @property - def weight(self): - return self.weight2 - - -def test_dual_conv3d_consistency(): - # Initialize parameters - in_channels = 3 - out_channels = 5 - kernel_size = (3, 3, 3) - stride = (2, 2, 2) - padding = (1, 1, 1) - - # Create an instance of the DualConv3d class - dual_conv3d = DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=True, - ) - - # Example input tensor - test_input = torch.randn(1, 3, 10, 10, 10) - - # Perform forward passes with both 3D and 2D settings - output_conv3d = dual_conv3d(test_input, use_conv3d=True) - output_2d = dual_conv3d(test_input, use_conv3d=False) - - # Assert that the outputs from both methods are sufficiently close - assert torch.allclose( - output_conv3d, output_2d, atol=1e-6 - ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py deleted file mode 100644 index 8cb7d7d68..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Optional, Union -from pathlib import Path -import os -import json - -import torch -import torch.nn as nn -from einops import rearrange -from diffusers import ConfigMixin, ModelMixin -from safetensors.torch import safe_open - -from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND - - -class ResBlock(nn.Module): - def __init__( - self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 - ): - super().__init__() - if mid_channels is None: - mid_channels = channels - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) - self.norm1 = nn.GroupNorm(32, mid_channels) - self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) - self.norm2 = nn.GroupNorm(32, channels) - self.activation = nn.SiLU() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = self.activation(x) - x = self.conv2(x) - x = self.norm2(x) - x = self.activation(x + residual) - return x - - -class LatentUpsampler(ModelMixin, ConfigMixin): - """ - Model to spatially upsample VAE latents. 
- - Args: - in_channels (`int`): Number of channels in the input latent - mid_channels (`int`): Number of channels in the middle layers - num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) - dims (`int`): Number of dimensions for convolutions (2 or 3) - spatial_upsample (`bool`): Whether to spatially upsample the latent - temporal_upsample (`bool`): Whether to temporally upsample the latent - """ - - def __init__( - self, - in_channels: int = 128, - mid_channels: int = 512, - num_blocks_per_stage: int = 4, - dims: int = 3, - spatial_upsample: bool = True, - temporal_upsample: bool = False, - ): - super().__init__() - - self.in_channels = in_channels - self.mid_channels = mid_channels - self.num_blocks_per_stage = num_blocks_per_stage - self.dims = dims - self.spatial_upsample = spatial_upsample - self.temporal_upsample = temporal_upsample - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) - self.initial_norm = nn.GroupNorm(32, mid_channels) - self.initial_activation = nn.SiLU() - - self.res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - if spatial_upsample and temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(3), - ) - elif spatial_upsample: - self.upsampler = nn.Sequential( - nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(2), - ) - elif temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(1), - ) - else: - raise ValueError( - "Either spatial_upsample or temporal_upsample must be True" - ) - - self.post_upsample_res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, latent: torch.Tensor) -> torch.Tensor: - b, c, f, h, w = latent.shape - - if self.dims == 2: - x = rearrange(latent, "b c f h w -> (b f) c h w") - x = self.initial_conv(x) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - x = self.upsampler(x) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - else: - x = self.initial_conv(latent) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - if self.temporal_upsample: - x = self.upsampler(x) - x = x[:, :, 1:, :, :] - else: - x = rearrange(x, "b c f h w -> (b f) c h w") - x = self.upsampler(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - - return x - - @classmethod - def from_config(cls, config): - return cls( - in_channels=config.get("in_channels", 4), - mid_channels=config.get("mid_channels", 128), - num_blocks_per_stage=config.get("num_blocks_per_stage", 4), - dims=config.get("dims", 2), - spatial_upsample=config.get("spatial_upsample", True), - temporal_upsample=config.get("temporal_upsample", False), - ) - - def config(self): - return { - "_class_name": "LatentUpsampler", - "in_channels": self.in_channels, - "mid_channels": self.mid_channels, - "num_blocks_per_stage": self.num_blocks_per_stage, - "dims": self.dims, - 
"spatial_upsample": self.spatial_upsample, - "temporal_upsample": self.temporal_upsample, - } - - @classmethod - def from_pretrained( - cls, - pretrained_model_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_path = Path(pretrained_model_path) - if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( - ".safetensors" - ): - state_dict = {} - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - config = json.loads(metadata["config"]) - with torch.device("meta"): - latent_upsampler = LatentUpsampler.from_config(config) - latent_upsampler.load_state_dict(state_dict, assign=True) - return latent_upsampler - - -if __name__ == "__main__": - latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) - print(latent_upsampler) - total_params = sum(p.numel() for p in latent_upsampler.parameters()) - print(f"Total number of parameters: {total_params:,}") - latent = torch.randn(1, 128, 9, 16, 16) - upsampled_latent = latent_upsampler(latent) - print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py deleted file mode 100644 index 9bc3ea60e..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch import nn - - -class PixelNorm(nn.Module): - def __init__(self, dim=1, eps=1e-8): - super(PixelNorm, self).__init__() - self.dim = dim - self.eps = eps - - def forward(self, x): - return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py deleted file mode 100644 index 4e79ae284..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch.nn as nn -from einops import rearrange - - -class PixelShuffleND(nn.Module): - def __init__(self, dims, upscale_factors=(2, 2, 2)): - super().__init__() - assert dims in [1, 2, 3], "dims must be 1, 2, or 3" - self.dims = dims - self.upscale_factors = upscale_factors - - def forward(self, x): - if self.dims == 3: - return rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - p3=self.upscale_factors[2], - ) - elif self.dims == 2: - return rearrange( - x, - "b (c p1 p2) h w -> b c (h p1) (w p2)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - ) - elif self.dims == 1: - return rearrange( - x, - "b (c p1) f h w -> b c (f p1) h w", - p1=self.upscale_factors[0], - ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py deleted file mode 100644 index 821a6b32b..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py +++ /dev/null @@ -1,380 +0,0 @@ -from typing import Optional, Union - -import torch -import inspect -import math -import torch.nn as nn -from diffusers import ConfigMixin, ModelMixin -from diffusers.models.autoencoders.vae import ( - DecoderOutput, - DiagonalGaussianDistribution, -) -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd - - -class AutoencoderKLWrapper(ModelMixin, ConfigMixin): - 
"""Variational Autoencoder (VAE) model with KL loss. - - VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. - This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. - - Args: - encoder (`nn.Module`): - Encoder module. - decoder (`nn.Module`): - Decoder module. - latent_channels (`int`, *optional*, defaults to 4): - Number of latent channels. - """ - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - latent_channels: int = 4, - dims: int = 2, - sample_size=512, - use_quant_conv: bool = True, - normalize_latent_channels: bool = False, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = encoder - self.use_quant_conv = use_quant_conv - self.normalize_latent_channels = normalize_latent_channels - - # pass init params to Decoder - quant_dims = 2 if dims == 2 else 3 - self.decoder = decoder - if use_quant_conv: - self.quant_conv = make_conv_nd( - quant_dims, 2 * latent_channels, 2 * latent_channels, 1 - ) - self.post_quant_conv = make_conv_nd( - quant_dims, latent_channels, latent_channels, 1 - ) - else: - self.quant_conv = nn.Identity() - self.post_quant_conv = nn.Identity() - - if normalize_latent_channels: - if dims == 2: - self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.Identity() - self.use_z_tiling = False - self.use_hw_tiling = False - self.dims = dims - self.z_sample_size = 1 - - self.decoder_params = inspect.signature(self.decoder.forward).parameters - - # only relevant if vae tiling is enabled - self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) - - def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): - self.tile_sample_min_size = sample_size - num_blocks = len(self.encoder.down_blocks) - self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) - self.tile_overlap_factor = overlap_factor - - def enable_z_tiling(self, z_sample_size: int = 8): - r""" - Enable tiling during VAE decoding. - - When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_z_tiling = z_sample_size > 1 - self.z_sample_size = z_sample_size - assert ( - z_sample_size % 8 == 0 or z_sample_size == 1 - ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." - - def disable_z_tiling(self): - r""" - Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_z_tiling = False - - def enable_hw_tiling(self): - r""" - Enable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = True - - def disable_hw_tiling(self): - r""" - Disable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = False - - def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent - - # Split the image into 512x512 tiles and encode them separately. 
- rows = [] - for i in range(0, x.shape[3], overlap_size): - row = [] - for j in range(0, x.shape[4], overlap_size): - tile = x[ - :, - :, - :, - i : i + self.tile_sample_min_size, - j : j + self.tile_sample_min_size, - ] - tile = self.encoder(tile) - tile = self.quant_conv(tile) - row.append(tile) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - moments = torch.cat(result_rows, dim=3) - return moments - - def blend_z( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[2], b.shape[2], blend_extent) - for z in range(blend_extent): - b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( - 1 - z / blend_extent - ) + b[:, :, z, :, :] * (z / blend_extent) - return b - - def blend_v( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[3], b.shape[3], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( - 1 - y / blend_extent - ) + b[:, :, :, y, :] * (y / blend_extent) - return b - - def blend_h( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[4], b.shape[4], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( - 1 - x / blend_extent - ) + b[:, :, :, :, x] * (x / blend_extent) - return b - - def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent - tile_target_shape = ( - *target_shape[:3], - self.tile_sample_min_size, - self.tile_sample_min_size, - ) - # Split z into overlapping 64x64 tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. 
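# The blend_v / blend_h helpers above hide tile seams by linearly cross-fading
# the overlapping strip of two neighbouring tiles. A minimal 1-D sketch of the
# same weighting (illustration only, not repository code):
import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Fade the last `blend_extent` columns of `a` into the first columns of `b`.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b

print(blend_1d(torch.ones(4), torch.zeros(4), blend_extent=4))
# tensor([1.0000, 0.7500, 0.5000, 0.2500])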
- rows = [] - for i in range(0, z.shape[3], overlap_size): - row = [] - for j in range(0, z.shape[4], overlap_size): - tile = z[ - :, - :, - :, - i : i + self.tile_latent_min_size, - j : j + self.tile_latent_min_size, - ] - tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, target_shape=tile_target_shape) - row.append(decoded) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - return dec - - def encode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: - num_splits = z.shape[2] // self.z_sample_size - sizes = [self.z_sample_size] * num_splits - sizes = ( - sizes + [z.shape[2] - sum(sizes)] - if z.shape[2] - sum(sizes) > 0 - else sizes - ) - tiles = z.split(sizes, dim=2) - moments_tiles = [ - ( - self._hw_tiled_encode(z_tile, return_dict) - if self.use_hw_tiling - else self._encode(z_tile) - ) - for z_tile in tiles - ] - moments = torch.cat(moments_tiles, dim=2) - - else: - moments = ( - self._hw_tiled_encode(z, return_dict) - if self.use_hw_tiling - else self._encode(z) - ) - - posterior = DiagonalGaussianDistribution(moments) - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - _, c, _, _, _ = z.shape - z = torch.cat( - [ - self.latent_norm_out(z[:, : c // 2, :, :, :]), - z[:, c // 2 :, :, :, :], - ], - dim=1, - ) - elif isinstance(self.latent_norm_out, nn.BatchNorm2d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) - running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) - eps = self.latent_norm_out.eps - - z = z * torch.sqrt(running_var + eps) + running_mean - elif isinstance(self.latent_norm_out, nn.BatchNorm3d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - moments = self._normalize_latent_channels(moments) - return moments - - def _decode( - self, - z: torch.FloatTensor, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - z = self._unnormalize_latent_channels(z) - z = self.post_quant_conv(z) - if "timestep" in self.decoder_params: - dec = self.decoder(z, target_shape=target_shape, timestep=timestep) - else: - dec = self.decoder(z, target_shape=target_shape) - return dec - - def decode( - self, - z: torch.FloatTensor, - return_dict: bool = True, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - assert target_shape is not None, "target_shape must be provided for decoding" - if self.use_z_tiling and 
z.shape[2] > self.z_sample_size > 1: - reduction_factor = int( - self.encoder.patch_size_t - * 2 - ** ( - len(self.encoder.down_blocks) - - 1 - - math.sqrt(self.encoder.patch_size) - ) - ) - split_size = self.z_sample_size // reduction_factor - num_splits = z.shape[2] // split_size - - # copy target shape, and divide frame dimension (=2) by the context size - target_shape_split = list(target_shape) - target_shape_split[2] = target_shape[2] // num_splits - - decoded_tiles = [ - ( - self._hw_tiled_decode(z_tile, target_shape_split) - if self.use_hw_tiling - else self._decode(z_tile, target_shape=target_shape_split) - ) - for z_tile in torch.tensor_split(z, num_splits, dim=2) - ] - decoded = torch.cat(decoded_tiles, dim=2) - else: - decoded = ( - self._hw_tiled_decode(z, target_shape) - if self.use_hw_tiling - else self._decode(z, target_shape=target_shape, timestep=timestep) - ) - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`DecoderOutput`] instead of a plain tuple. - generator (`torch.Generator`, *optional*): - Generator used to sample from the posterior. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z, target_shape=sample.shape).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py deleted file mode 100644 index 5a0aeeccf..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py +++ /dev/null @@ -1,247 +0,0 @@ -from typing import Tuple -import torch -from diffusers import AutoencoderKL -from einops import rearrange -from torch import Tensor - - -from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( - CausalVideoAutoencoder, -) -from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import ( - Downsample3D, - VideoAutoencoder, -) - -try: - import torch_xla.core.xla_model as xm -except ImportError: - xm = None - - -def vae_encode( - media_items: Tensor, - vae: AutoencoderKL, - split_size: int = 1, - vae_per_channel_normalize=False, -) -> Tensor: - """ - Encodes media items (images or videos) into latent representations using a specified VAE model. - The function supports processing batches of images or video frames and can handle the processing - in smaller sub-batches if needed. - - Args: - media_items (Tensor): A torch Tensor containing the media items to encode. The expected - shape is (batch_size, channels, height, width) for images or (batch_size, channels, - frames, height, width) for videos. - vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, - pre-configured and loaded with the appropriate model weights. - split_size (int, optional): The number of sub-batches to split the input batch into for encoding. - If set to more than 1, the input media items are processed in smaller batches according to - this value. 
Defaults to 1, which processes all items in a single batch. - - Returns: - Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted - to match the input shape, scaled by the model's configuration. - - Examples: - >>> import torch - >>> from diffusers import AutoencoderKL - >>> vae = AutoencoderKL.from_pretrained('your-model-name') - >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. - >>> latents = vae_encode(images, vae) - >>> print(latents.shape) # Output shape will depend on the model's latent configuration. - - Note: - In case of a video, the function encodes the media item frame-by frame. - """ - is_video_shaped = media_items.dim() == 5 - batch_size, channels = media_items.shape[0:2] - - if channels != 3: - raise ValueError(f"Expects tensors with 3 channels, got {channels}.") - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - media_items = rearrange(media_items, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(media_items) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(media_items) // split_size - # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] - latents = [] - if media_items.device.type == "xla": - xm.mark_step() - for image_batch in media_items.split(encode_bs): - latents.append(vae.encode(image_batch).latent_dist.sample()) - if media_items.device.type == "xla": - xm.mark_step() - latents = torch.cat(latents, dim=0) - else: - latents = vae.encode(media_items).latent_dist.sample() - - latents = normalize_latents(latents, vae, vae_per_channel_normalize) - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) - return latents - - -def vae_decode( - latents: Tensor, - vae: AutoencoderKL, - is_video: bool = True, - split_size: int = 1, - vae_per_channel_normalize=False, - timestep=None, -) -> Tensor: - is_video_shaped = latents.dim() == 5 - batch_size = latents.shape[0] - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(latents) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(latents) // split_size - image_batch = [ - _run_decoder( - latent_batch, vae, is_video, vae_per_channel_normalize, timestep - ) - for latent_batch in latents.split(encode_bs) - ] - images = torch.cat(image_batch, dim=0) - else: - images = _run_decoder( - latents, vae, is_video, vae_per_channel_normalize, timestep - ) - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) - return images - - -def _run_decoder( - latents: Tensor, - vae: AutoencoderKL, - is_video: bool, - vae_per_channel_normalize=False, - timestep=None, -) -> Tensor: - if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): - *_, fl, hl, wl = latents.shape - temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) - latents = latents.to(vae.dtype) - vae_decode_kwargs = {} - if timestep is not None: - vae_decode_kwargs["timestep"] = timestep - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - 
return_dict=False, - target_shape=( - 1, - 3, - fl * temporal_scale if is_video else 1, - hl * spatial_scale, - wl * spatial_scale, - ), - **vae_decode_kwargs, - )[0] - else: - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - return_dict=False, - )[0] - return image - - -def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: - if isinstance(vae, CausalVideoAutoencoder): - spatial = vae.spatial_downscale_factor - temporal = vae.temporal_downscale_factor - else: - down_blocks = len( - [ - block - for block in vae.encoder.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - spatial = vae.config.patch_size * 2**down_blocks - temporal = ( - vae.config.patch_size_t * 2**down_blocks - if isinstance(vae, VideoAutoencoder) - else 1 - ) - - return (temporal, spatial, spatial) - - -def latent_to_pixel_coords( - latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False -) -> Tensor: - """ - Converts latent coordinates to pixel coordinates by scaling them according to the VAE's - configuration. - - Args: - latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] - containing the latent corner coordinates of each token. - vae (AutoencoderKL): The VAE model - causal_fix (bool): Whether to take into account the different temporal scale - of the first frame. Default = False for backwards compatibility. - Returns: - Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. - """ - - scale_factors = get_vae_size_scale_factor(vae) - causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix - pixel_coords = latent_to_pixel_coords_from_factors( - latent_coords, scale_factors, causal_fix - ) - return pixel_coords - - -def latent_to_pixel_coords_from_factors( - latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False -) -> Tensor: - pixel_coords = ( - latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] - ) - if causal_fix: - # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) - return pixel_coords - - -def normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) - / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents * vae.config.scaling_factor - ) - - -def un_normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents / vae.config.scaling_factor - ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py deleted file mode 100644 index 5b9ea640b..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py +++ /dev/null @@ -1,1045 +0,0 @@ -import json -import os -from functools import partial -from types import SimpleNamespace -from typing import Any, Mapping, Optional, Tuple, Union - -import torch -from einops import rearrange -from torch import nn -from torch.nn import functional - -from diffusers.utils import logging - -from maxdiffusion.models.ltx_video.utils.torch_utils import Identity -from 
maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm -from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper - -logger = logging.get_logger(__name__) - - -class VideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - video_vae = cls.from_config(config) - video_vae.to(kwargs["torch_dtype"]) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - ckpt_state_dict = torch.load(model_local_path) - video_vae.load_state_dict(ckpt_state_dict) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) - video_vae.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) - - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "VideoAutoencoder" - ), "config must have _class_name=VideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - - if use_quant_conv and latent_log_var == "uniform": - raise ValueError("uniform latent_log_var requires use_quant_conv=False") - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - dims = config["dims"] - return VideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="VideoAutoencoder", - dims=self.dims, - in_channels=self.encoder.conv_in.in_channels - // (self.encoder.patch_size_t * self.encoder.patch_size**2), - out_channels=self.decoder.conv_out.out_channels - // (self.decoder.patch_size_t * self.decoder.patch_size**2), - latent_channels=self.decoder.conv_in.in_channels, - block_out_channels=[ - 
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels - for i in range(len(self.encoder.down_blocks)) - ], - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - patch_size_t=self.encoder.patch_size_t, - add_channel_padding=self.encoder.add_channel_padding, - ) - - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def downscale_factor(self): - return self.encoder.downsample_factor - - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - model_keys = set(name for name, _ in self.named_parameters()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - - converted_state_dict = {} - for key, value in state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - if "norm" in key and key not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer - - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] 
= (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - if add_channel_padding: - in_channels = in_channels * self.patch_size**3 - else: - in_channels = in_channels * self.patch_size_t * self.patch_size**2 - self.in_channels = in_channels - output_channel = block_out_channels[0] - - self.conv_in = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - padding=1, - ) - - self.down_blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels)): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - num_layers=self.layers_per_block, - add_downsample=not is_final_block and 2**i >= patch_size, - resnet_eps=1e-6, - downsample_padding=0, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.down_blocks.append(down_block) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], - num_groups=norm_num_groups, - eps=1e-6, - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, block_out_channels[-1], conv_out_channels, 3, padding=1 - ) - - self.gradient_checkpointing = False - - @property - def downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - * self.patch_size - ) - - def forward( - self, sample: torch.FloatTensor, return_features=False - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - downsample_in_time = sample.shape[2] != 1 - - # patchify - patch_size_t = self.patch_size_t if downsample_in_time else 1 - sample = patchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) - - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - if return_features: - features = [] - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)( - sample, downsample_in_time=downsample_in_time - ) - if return_features: - features.append(sample) - - sample = checkpoint_fn(self.mid_block)(sample) - - # post-process - sample = 
self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - - if return_features: - features.append(sample[:, : self.latent_channels, ...]) - return sample, features - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - """ - - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - if add_channel_padding: - out_channels = out_channels * self.patch_size**3 - else: - out_channels = out_channels * self.patch_size_t * self.patch_size**2 - self.out_channels = out_channels - - self.conv_in = make_conv_nd( - dims, - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - dims=dims, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block - and 2 ** (len(block_out_channels) - i - 1) > patch_size, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.up_blocks.append(up_block) - - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, block_out_channels[0], out_channels, 3, padding=1 - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - upsample_in_time = sample.shape[2] < target_shape[2] - - sample = self.conv_in(sample) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - sample = checkpoint_fn(self.mid_block)(sample) - sample = sample.to(upscale_dtype) - - for up_block in self.up_blocks: - sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - # un-patchify - patch_size_t = self.patch_size_t if upsample_in_time else 1 - sample = unpatchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) - - return sample - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_downsample: bool = True, - downsample_padding: int = 1, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - - if add_downsample: - self.downsample = Downsample3D( - dims, - out_channels, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsample = Identity() - - def forward( - self, hidden_states: torch.FloatTensor, downsample_in_time - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - hidden_states = self.downsample( - hidden_states, downsample_in_time=downsample_in_time - ) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. 
- - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_upsample: bool = True, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=input_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - - if add_upsample: - self.upsample = Upsample3D( - dims=dims, channels=out_channels, out_channels=out_channels - ) - else: - self.upsample = Identity() - - self.resolution_idx = resolution_idx - - def forward( - self, hidden_states: torch.FloatTensor, upsample_in_time=True - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
- """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - if norm_layer == "group_norm": - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() - - self.non_linearity = nn.SiLU() - - self.conv1 = make_conv_nd( - dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if norm_layer == "group_norm": - self.norm2 = torch.nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = make_conv_nd( - dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) - - def forward( - self, - input_tensor: torch.FloatTensor, - ) -> torch.FloatTensor: - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - - hidden_states = self.norm2(hidden_states) - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - - hidden_states = self.conv2(hidden_states) - - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -class Downsample3D(nn.Module): - def __init__( - self, - dims, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - padding: int = 1, - ): - super().__init__() - stride: int = 2 - self.padding = padding - self.in_channels = in_channels - self.dims = dims - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - - def forward(self, x, downsample_in_time=True): - conv = self.conv - if self.padding == 0: - if self.dims == 2: - padding = (0, 1, 0, 1) - else: - padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) - - x = functional.pad(x, padding, mode="constant", value=0) - - if self.dims == (2, 1) and not downsample_in_time: - return conv(x, skip_time_conv=True) - - return conv(x) - - -class Upsample3D(nn.Module): - """ - An upsampling layer for 3D tensors of shape (B, C, D, H, W). - - :param channels: channels in the inputs and outputs. 
- """ - - def __init__(self, dims, channels, out_channels=None): - super().__init__() - self.dims = dims - self.channels = channels - self.out_channels = out_channels or channels - self.conv = make_conv_nd( - dims, channels, out_channels, kernel_size=3, padding=1, bias=True - ) - - def forward(self, x, upsample_in_time): - if self.dims == 2: - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - else: - time_scale_factor = 2 if upsample_in_time else 1 - # print("before:", x.shape) - b, c, d, h, w = x.shape - x = rearrange(x, "b c d h w -> (b d) c h w") - # height and width interpolate - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - _, _, h, w = x.shape - - if not upsample_in_time and self.dims == (2, 1): - x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) - return self.conv(x, skip_time_conv=True) - - # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) - - # (b h w) c 1 d - new_d = x.shape[-1] * time_scale_factor - x = functional.interpolate(x, (1, new_d), mode="nearest") - # (b h w) c 1 new_d - x = rearrange( - x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d - ) - # b c d h w - - # x = functional.interpolate( - # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - # ) - # print("after:", x.shape) - - return self.conv(x) - - -def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] - padding_zeros = torch.zeros( - x.shape[0], - channels_to_pad, - x.shape[2], - x.shape[3], - x.shape[4], - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([padding_zeros, x], dim=1) - - return x - - -def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) - x = x[:, :channels_to_keep, :, :, :] - - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - - return x - - -def create_video_autoencoder_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [ - 128, - 256, - 512, - 512, - ], # Number of output channels of each encoder / decoder inner block - 
"patch_size": 1, - } - - return config - - -def create_video_autoencoder_pathify4x4x4_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "latent_log_var": "uniform", - } - - return config - - -def create_video_autoencoder_pathify4x4_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "norm_layer": "pixel_norm", - } - - return config - - -def test_vae_patchify_unpatchify(): - import torch - - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) - - -def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_pathify4x4x4_config() - - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = VideoAutoencoder.from_config(config) - - print(video_autoencoder) - - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 8, 64, 64) - - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape - ).sample - - print(f"reconstructed shape={reconstructed_videos.shape}") - - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - - # Perform backward pass - loss.backward() - - print(f"Demo completed with loss: {loss.item()}") - - -# Ensure to call the demo function to execute the forward and backward pass -if __name__ == "__main__": - demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py similarity index 100% rename from src/maxdiffusion/models/ltx_video/autoencoders/__init__.py rename to src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 06b367ce4..31c6b5b15 100644 --- 
a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -8,109 +8,109 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + @nn.compact + def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: + data_input, index_input = carry - @nn.compact - def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: - data_input, index_input = carry + mod = self.module(*self.module_init_args, **self.module_init_kwargs) - mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers - output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers + next_index = index_input + 1 + new_carry = (output_data, next_index) - next_index = index_input + 1 - new_carry = (output_data, next_index) - + return new_carry, None - return new_carry, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - if not args: - raise ValueError("RepeatableLayer expects at least one argument for initial data input.") + @nn.compact + def __call__(self, *args): + if not args: + raise ValueError("RepeatableLayer expects at least one argument for initial data input.") - initial_data_input = args[0] - static_block_args = args[1:] + 
initial_data_input = args[0] + static_block_args = args[1:] - initial_index = jnp.array(0, dtype=jnp.int32) #index of current transformer block + initial_index = jnp.array(0, dtype=jnp.int32) # index of current transformer block - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) - in_axes_for_scan = (nn.broadcast,) * (len(args)-1) + in_axes_for_scan = (nn.broadcast,) * (len(args) - 1) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={"params": True}, - in_axes=in_axes_for_scan, - length=self.num_layers, - **scan_kwargs, - ) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True}, + in_axes=in_axes_for_scan, + length=self.num_layers, + **scan_kwargs, + ) - wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) - # Call wrapped_function with the initial carry tuple and the static_block_args - (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) + # Call wrapped_function with the initial carry tuple and the static_block_args + (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) - return final_data \ No newline at end of file + return final_data diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index caff27add..75692b703 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -15,7 +15,6 @@ # This implementation is based on the Torch version available at: # https://github.com/Lightricks/LTX-Video/tree/main from functools import partial -import functools import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto @@ -40,899 +39,885 @@ ) - class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - def __call__(self, x): - return x + + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: 
str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." - - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. 
Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + index: int, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + skip_layer_strategy = SkipLayerStrategy.AttentionValues + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index=index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + block_index=-1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. 
Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), dtype=self.dtype, weight_dtype=self.weight_dtype, + name="to_out.0", matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - - def __call__( - self, - index: int, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - skip_layer_strategy = SkipLayerStrategy.AttentionValues - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape( + skip_layer_mask[block_index], (batch_size, 1, 1) + ) # here skip_layer_mask is (48,3), changed this currently! 
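To make the skip-layer (STG) masking concrete: the per-block mask row is looked up, broadcast over tokens and channels, and later used to blend the attention output with the raw values. A minimal sketch with hypothetical shapes (48 blocks, batch of 3, as the comment above suggests):

import jax.numpy as jnp

num_blocks, batch, tokens, dim = 48, 3, 16, 64
skip_layer_mask = jnp.ones((num_blocks, batch))    # 1.0 keeps a block, 0.0 skips it
mask = skip_layer_mask[7].reshape(batch, 1, 1)     # row for block_index=7, broadcastable

attn_out = jnp.zeros((batch, tokens, dim))         # stand-in for hidden_states_a
values = jnp.ones((batch, tokens, dim))            # stand-in for value_for_stg
# SkipLayerStrategy.AttentionValues: masked-out batch entries fall back to the values.
blended = attn_out * mask + values * (1.0 - mask)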
+ + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def 
prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - # 0. Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. 
Self-Attention - - attn_output = self.attn1( - norm_hidden_states, - block_index = index, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) - attn_output = nn.with_logical_constraint( - attn_output, ("activation_batch", "activation_norm_length", "activation_embed") - ) +class AttentionOp(nn.Module): - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint( - attn_input, ("activation_batch", "activation_norm_length", "activation_embed") - ) - attn_output = self.attn2( - attn_input, - block_index = -1, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
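The hard-coded partition specs a few lines below map the (batch, heads, tokens, head_dim) layout of q/k/v onto the mesh. A small sketch of how those specs are built (mesh axis names as in the code; everything else illustrative):

from jax.sharding import PartitionSpec

# q/k/v/o are (batch, heads, tokens, head_dim): batch is sharded over "data", heads
# over "fsdp", head_dim over "tensor"; the token axis is replicated.
qkvo_spec = PartitionSpec("data", "fsdp", None, "tensor")
# Segment ids are (batch, tokens); only the batch axis is sharded.
segment_ids_spec = PartitionSpec("data", None)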
+ # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkvo_sharding_spec = jax.sharding.PartitionSpec( + "data", + "fsdp", + None, + "tensor", + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint( - ff_output, ("activation_batch", "activation_norm_length", "activation_embed") - ) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states + TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = 
Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." - ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="to_out.0", - matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any(device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print("Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - block_index: int = -1, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> 
jnp.ndarray: - cross_attention_kwargs = { k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters } - assert cross_attention_kwargs.get("scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) #here skip_layer_mask is (48,3), changed this currently! - - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key = self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - value_for_stg = value - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") - - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") + ** SRAM cache size for TPU + V5P - 1MB SRAM per core - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") - - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) 
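A quick numerical check of the block-size heuristic above, with hypothetical sequence lengths (the values depend only on the dtype and the q and k/v token counts):

import jax.numpy as jnp

dtype = jnp.bfloat16
q_tokens, kv_tokens = 4096, 256
max_block_size = 1024 if dtype == jnp.bfloat16 else 512
block_q = min(max_block_size, q_tokens)   # 1024: the kernel grid walks 4 query blocks
block_k = min(max_block_size, kv_tokens)  # 256: the short key/value side fits in one block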
- if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - hidden_states_a = self.attention( - query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype - ) +class ExplicitAttention(nn.Module): - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") - ) + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v - hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) - elif ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.AttentionValues - ): - hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( - 1.0 - skip_layer_mask - ) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = 
jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. -class AttentionOp(nn.Module): - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
- # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # ("data", "fsdp", "fsdp_transpose", "expert"), - # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - # None, - # None, - # ) - # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - qkvo_sharding_spec = jax.sharding.PartitionSpec( - "data", - "fsdp", - None, - "tensor", - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. - - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can - - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM - - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size - - ** SRAM cache size for TPU - V5P - 1MB SRAM per core - - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used - - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - ) + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + Returns: + jax.Array: Normed data + """ -class ExplicitAttention(nn.Module): - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." 
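For reference, a tiny end-to-end sketch of the explicit (non-flash) attention fallback introduced above, including the segment-id mask it builds; all shapes here are illustrative:

import jax
import jax.numpy as jnp

# (batch=1, heads=1, tokens=4, head_dim=8); tokens {0,1} and {2,3} form two segments.
q = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 1, 4, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (1, 1, 4, 8))
seg = jnp.array([[0.0, 0.0, 1.0, 1.0]])

mask = seg[:, None, :, None] == seg[:, None, None, :]   # block-diagonal visibility
scores = q @ k.swapaxes(-2, -1) / q.shape[-1] ** 0.5
scores = jnp.where(mask, scores, -jnp.inf)              # masked positions get -inf bias
out = jax.nn.softmax(scores, axis=-1) @ v               # (1, 1, 4, 8)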
- attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. - """ + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. - - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. - - NOTE: if weight is in mixed precision, the operand should be in the same precision. - Args: - hidden_states (jax.Array): Input data - - Returns: - jax.Array: Normed data - """ - - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) - - if self.elementwise_affine: - # convert into half-precision if necessary - hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) - else: - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError(f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out \ No newline at end of file + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py deleted file mode 100644 index 2eca32033..000000000 --- a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py +++ /dev/null @@ -1,84 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Tuple - -import torch -from diffusers.configuration_utils import ConfigMixin -from einops import rearrange -from torch import Tensor - - -class Patchifier(ConfigMixin, ABC): - def __init__(self, patch_size: int): - super().__init__() - self._patch_size = (1, patch_size, patch_size) - - @abstractmethod - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - raise NotImplementedError("Patchify method not implemented") - - @abstractmethod - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - pass - - @property - def patch_size(self): - return self._patch_size - - def get_latent_coords( - self, latent_num_frames, latent_height, latent_width, batch_size, device - ): - """ - Return a tensor of shape [batch_size, 3, num_patches] containing the - top-left corner latent coordinates of each latent patch. - The tensor is repeated for each batch element. 
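
apply_rotary_emb above treats consecutive feature pairs as 2-D vectors and rotates each pair by its position-dependent angle (cos_freqs and sin_freqs repeat every frequency twice, so both elements of a pair share one angle). A toy check of that identity, not taken from the patch:

import jax.numpy as jnp

x_pair = jnp.array([1.0, 0.0])                      # one (x1, x2) feature pair
theta = jnp.pi / 2                                  # angle encoded by cos/sin at this position
cos, sin = jnp.cos(theta), jnp.sin(theta)

rotated_pair = jnp.array([-x_pair[1], x_pair[0]])   # the (-x2, x1) construction from the code
out = x_pair * cos + rotated_pair * sin             # -> approximately [0., 1.], a plain rotation
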
- """ - latent_sample_coords = torch.meshgrid( - torch.arange(0, latent_num_frames, self._patch_size[0], device=device), - torch.arange(0, latent_height, self._patch_size[1], device=device), - torch.arange(0, latent_width, self._patch_size[2], device=device), - ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size - ) - return latent_coords - - -class SymmetricPatchifier(Patchifier): - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - b, _, f, h, w = latents.shape - latent_coords = self.get_latent_coords(f, h, w, b, latents.device) - latents = rearrange( - latents, - "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", - p1=self._patch_size[0], - p2=self._patch_size[1], - p3=self._patch_size[2], - ) - return latents, latent_coords - - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - output_height = output_height // self._patch_size[1] - output_width = output_width // self._patch_size[2] - latents = rearrange( - latents, - "b (f h w) (c p q) -> b c f (h p) (w q)", - h=output_height, - w=output_width, - p=self._patch_size[1], - q=self._patch_size[2], - ) - return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index e3110a82b..1c1807fdd 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -29,304 +29,298 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
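
With patch_size (1, 1, 1), as the pipeline further down instantiates it, the removed SymmetricPatchifier is essentially a reshape from a latent video [b, c, f, h, w] to a token sequence [b, f*h*w, c]. A minimal einops sketch with illustrative shapes (einops also supports JAX arrays):

import jax.numpy as jnp
from einops import rearrange

b, c, f, h, w = 1, 128, 9, 16, 16
latents = jnp.zeros((b, c, f, h, w))
tokens = rearrange(latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", p1=1, p2=1, p3=1)
assert tokens.shape == (1, 9 * 16 * 16, 128)  # one token per latent position, channels preserved
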
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - def init_weights(self, in_channels, key, caption_channels, eval_only=True): - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "hidden_states": (batch_size, num_tokens, in_channels), - "indices_grid": (batch_size, 3, num_tokens), - "encoder_hidden_states": (batch_size, 128, caption_channels), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] 
else jnp.bool - ) - - if eval_only: - return jax.eval_shape( - self.init, - key, - **example_inputs, - )["params"] - else: - return self.init(key, **example_inputs)["params"] - def create_skip_layer_mask( - self, - batch_size: int, - num_conds: int, - ptb_index: int, - skip_block_list: Optional[List[int]] = None, - ) -> Optional[jnp.ndarray]: - if skip_block_list is None or len(skip_block_list) == 0: - return None - mask = jnp.ones( - (self.num_layers, batch_size * num_conds), dtype=self.dtype - ) - - for block_idx in skip_block_list: - mask = mask.at[block_idx, ptb_index::num_conds].set(0) - - return mask - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - skip_layer_mask=None, - skip_layer_strategy=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - - if self.num_layers > 0: - - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - skip_layer_mask, - skip_layer_strategy, - ) - # Output processing - - scale_shift_values = ( - self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] - ) - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + 
matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( #noqa: C408 + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + 
"timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + key, + **example_inputs, + )["params"] + else: + return self.init(key, **example_inputs)["params"] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones((self.num_layers, batch_size * num_conds), dtype=self.dtype) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. - - Args: - x (jax.Array): log value - base (jax.Array): base of the log - - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) - + """ + Computes log of x with defined base. + Args: + x (jax.Array): log value + base (jax.Array): base of the log + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. 
- """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - dtype = jnp.float32 # We need full precision in the freqs_cis computation. - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
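
FreqsCisPrecomputer maps each latent coordinate to a fractional position, remaps it to [-1, 1], and multiplies it by dim // 6 log-spaced frequencies between 1 and theta (scaled by pi / 2); the frame, height and width axes each contribute their own third of the rotary dimensions before the cos/sin pairs are interleaved. A one-axis sketch of that computation (simplified; the real module also pads when inner_dim is not divisible by 6):

import jax
import jax.numpy as jnp

def rope_freqs_one_axis(positions: jax.Array, max_pos: float, n_freqs: int, theta: float = 10000.0):
  frac = positions / max_pos * 2.0 - 1.0                          # fractional position remapped to [-1, 1]
  freqs = theta ** jnp.linspace(0.0, 1.0, n_freqs) * jnp.pi / 2   # log-spaced between 1 and theta, scaled by pi/2
  angles = frac[..., None] * freqs                                # [..., n_freqs] rotation angles
  return jnp.cos(angles), jnp.sin(angles)

With the configuration shown further down (positional_embedding_max_pos = [20, 2048, 2048], positional_embedding_theta = 10000.0), each axis is normalized by its own maximum extent.
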
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py deleted file mode 100644 index 53c0082d1..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py +++ /dev/null @@ -1,174 +0,0 @@ -def make_hashable_key(dict_key): - def convert_value(value): - if isinstance(value, list): - return tuple(value) - elif isinstance(value, dict): - return tuple(sorted((k, convert_value(v)) for k, v in value.items())) - else: - return value - - return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) - - -DIFFUSERS_SCHEDULER_CONFIG = { - "_class_name": "FlowMatchEulerDiscreteScheduler", - "_diffusers_version": "0.32.0.dev0", - "base_image_seq_len": 1024, - "base_shift": 0.95, - "invert_sigmas": False, - "max_image_seq_len": 4096, - "max_shift": 2.05, - "num_train_timesteps": 1000, - "shift": 1.0, - "shift_terminal": 0.1, - "use_beta_sigmas": False, - "use_dynamic_shifting": True, - "use_exponential_sigmas": False, - "use_karras_sigmas": False, -} -DIFFUSERS_TRANSFORMER_CONFIG = { - "_class_name": "LTXVideoTransformer3DModel", - "_diffusers_version": "0.32.0.dev0", - "activation_fn": "gelu-approximate", - "attention_bias": True, - "attention_head_dim": 64, - "attention_out_bias": True, - "caption_channels": 4096, - "cross_attention_dim": 2048, - "in_channels": 128, - "norm_elementwise_affine": False, - "norm_eps": 1e-06, - "num_attention_heads": 32, - "num_layers": 28, - "out_channels": 128, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm_across_heads", -} -DIFFUSERS_VAE_CONFIG = { - "_class_name": "AutoencoderKLLTXVideo", - "_diffusers_version": "0.32.0.dev0", - "block_out_channels": [128, 256, 512, 512], - "decoder_causal": False, - "encoder_causal": True, - "in_channels": 3, - "latent_channels": 128, - "layers_per_block": [4, 3, 3, 3, 4], - "out_channels": 3, - "patch_size": 4, - "patch_size_t": 1, - "resnet_norm_eps": 1e-06, - "scaling_factor": 1.0, - "spatio_temporal_scaling": [True, True, True, False], -} - -OURS_SCHEDULER_CONFIG = { - "_class_name": "RectifiedFlowScheduler", - "_diffusers_version": "0.25.1", - "num_train_timesteps": 1000, - "shifting": "SD3", - "base_resolution": None, - "target_shift_terminal": 0.1, -} - -OURS_TRANSFORMER_CONFIG = { - "_class_name": "Transformer3DModel", - "_diffusers_version": "0.25.1", - "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", - "activation_fn": "gelu-approximate", - "attention_bias": True, - "attention_head_dim": 64, - "attention_type": 
"default", - "caption_channels": 4096, - "cross_attention_dim": 2048, - "double_self_attention": False, - "dropout": 0.0, - "in_channels": 128, - "norm_elementwise_affine": False, - "norm_eps": 1e-06, - "norm_num_groups": 32, - "num_attention_heads": 32, - "num_embeds_ada_norm": 1000, - "num_layers": 28, - "num_vector_embeds": None, - "only_cross_attention": False, - "out_channels": 128, - "project_to_2d_pos": True, - "upcast_attention": False, - "use_linear_projection": False, - "qk_norm": "rms_norm", - "standardization_norm": "rms_norm", - "positional_embedding_type": "rope", - "positional_embedding_theta": 10000.0, - "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000, -} -OURS_VAE_CONFIG = { - "_class_name": "CausalVideoAutoencoder", - "dims": 3, - "in_channels": 3, - "out_channels": 3, - "latent_channels": 128, - "blocks": [ - ["res_x", 4], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x", 3], - ["res_x", 4], - ], - "scaling_factor": 1.0, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, -} - - -diffusers_and_ours_config_mapping = { - make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, - make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, - make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, -} - - -TRANSFORMER_KEYS_RENAME_DICT = { - "proj_in": "patchify_proj", - "time_embed": "adaln_single", - "norm_q": "q_norm", - "norm_k": "k_norm", -} - - -VAE_KEYS_RENAME_DICT = { - "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", - "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", - "decoder.up_blocks.3": "decoder.up_blocks.9", - "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", - "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", - "decoder.up_blocks.2": "decoder.up_blocks.6", - "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", - "decoder.up_blocks.1": "decoder.up_blocks.3", - "decoder.up_blocks.0": "decoder.up_blocks.1", - "decoder.mid_block": "decoder.up_blocks.0", - "encoder.down_blocks.3": "encoder.down_blocks.8", - "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", - "encoder.down_blocks.2": "encoder.down_blocks.6", - "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", - "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", - "encoder.down_blocks.1": "encoder.down_blocks.3", - "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", - "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", - "encoder.down_blocks.0": "encoder.down_blocks.0", - "encoder.mid_block": "encoder.down_blocks.9", - "conv_shortcut.conv": "conv_shortcut", - "resnets": "res_blocks", - "norm3": "norm3.norm", - "latents_mean": "per_channel_statistics.mean-of-means", - "latents_std": "per_channel_statistics.std-of-means", -} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py deleted file mode 100644 index 901051728..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py +++ /dev/null @@ -1,226 +0,0 @@ -import logging -from typing import Union, List, Optional - -import torch -from PIL import Image - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When 
writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. -Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. -Start directly with the action, and keep descriptions literal and precise. -Think like a cinematographer describing a shot list. -Do not change the user input intent, just enhance it. -Keep within 150 words. -For best results, build your prompts using this structure: -Start with main action in a single sentence -Add specific details about movements and gestures -Describe character/object appearances precisely -Include background and environment details -Specify camera angles and movements -Describe lighting and colors -Note any changes or sudden events -Do not exceed the 150 word limit! -Output the enhanced prompt only. -""" - -I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. -Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. -Start directly with the action, and keep descriptions literal and precise. -Think like a cinematographer describing a shot list. -Keep within 150 words. -For best results, build your prompts using this structure: -Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. -Start with main action in a single sentence -Add specific details about movements and gestures -Describe character/object appearances precisely -Include background and environment details -Specify camera angles and movements -Describe lighting and colors -Note any changes or sudden events -Align to the image caption if it contradicts the user text input. -Do not exceed the 150 word limit! -Output the enhanced prompt only. 
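
For context, the deleted helpers further down feed these system prompts to a chat LLM through the standard transformers chat-template flow; a condensed sketch of that flow (the model and tokenizer objects and the token budget are stand-ins mirroring the removed code):

import torch

def enhance_prompt(llm, tokenizer, user_prompt: str, system_prompt: str, max_new_tokens: int = 256) -> str:
  messages = [
      {"role": "system", "content": system_prompt},
      {"role": "user", "content": f"user_prompt: {user_prompt}"},
  ]
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  inputs = tokenizer(text, return_tensors="pt").to(llm.device)
  with torch.inference_mode():
    output_ids = llm.generate(**inputs, max_new_tokens=max_new_tokens)
  # Decode only the newly generated continuation, not the echoed prompt.
  new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
  return tokenizer.decode(new_tokens, skip_special_tokens=True)
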
-""" - - -def tensor_to_pil(tensor): - # Ensure tensor is in range [-1, 1] - assert tensor.min() >= -1 and tensor.max() <= 1 - - # Convert from [-1, 1] to [0, 1] - tensor = (tensor + 1) / 2 - - # Rearrange from [C, H, W] to [H, W, C] - tensor = tensor.permute(1, 2, 0) - - # Convert to numpy array and then to uint8 range [0, 255] - numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") - - # Convert to PIL Image - return Image.fromarray(numpy_image) - - -def generate_cinematic_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompt: Union[str, List[str]], - conditioning_items: Optional[List] = None, - max_new_tokens: int = 256, -) -> List[str]: - prompts = [prompt] if isinstance(prompt, str) else prompt - - if conditioning_items is None: - prompts = _generate_t2v_prompt( - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - max_new_tokens, - T2V_CINEMATIC_PROMPT, - ) - else: - if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: - logger.warning( - "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" - ) - return prompts - - first_frame_conditioning_item = conditioning_items[0] - first_frames = _get_first_frames_from_conditioning_item( - first_frame_conditioning_item - ) - - assert len(first_frames) == len( - prompts - ), "Number of conditioning frames must match number of prompts" - - prompts = _generate_i2v_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - first_frames, - max_new_tokens, - I2V_CINEMATIC_PROMPT, - ) - - return prompts - - -def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: - frames_tensor = conditioning_item.media_item - return [ - tensor_to_pil(frames_tensor[i, :, 0, :, :]) - for i in range(frames_tensor.shape[0]) - ] - - -def _generate_t2v_prompt( - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts: List[str], - max_new_tokens: int, - system_prompt: str, -) -> List[str]: - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}"}, - ] - for p in prompts - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) - - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens - ) - - -def _generate_i2v_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts: List[str], - first_frames: List[Image.Image], - max_new_tokens: int, - system_prompt: str, -) -> List[str]: - image_captions = _generate_image_captions( - image_caption_model, image_caption_processor, first_frames - ) - - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, - ] - for p, c in zip(prompts, image_captions) - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) - - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, 
max_new_tokens - ) - - -def _generate_image_captions( - image_caption_model, - image_caption_processor, - images: List[Image.Image], - system_prompt: str = "", -) -> List[str]: - image_caption_prompts = [system_prompt] * len(images) - inputs = image_caption_processor( - image_caption_prompts, images, return_tensors="pt" - ).to(image_caption_model.device) - - with torch.inference_mode(): - generated_ids = image_caption_model.generate( - input_ids=inputs["input_ids"], - pixel_values=inputs["pixel_values"], - max_new_tokens=1024, - do_sample=False, - num_beams=3, - ) - - return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) - - -def _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int -) -> List[str]: - with torch.inference_mode(): - outputs = prompt_enhancer_model.generate( - **model_inputs, max_new_tokens=max_new_tokens - ) - generated_ids = [ - output_ids[len(input_ids) :] - for input_ids, output_ids in zip(model_inputs.input_ids, outputs) - ] - decoded_prompts = prompt_enhancer_tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - ) - - return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py deleted file mode 100644 index 30f9016e1..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum, auto - - -class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py deleted file mode 100644 index 991b07c36..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch import nn - - -def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - elif dims_to_append == 0: - return x - return x[(...,) + (None,) * dims_to_append] - - -class Identity(nn.Module): - """A placeholder identity operator that is argument-insensitive.""" - - def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument - super().__init__() - - # pylint: disable=unused-argument - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - return x diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 5584fcc0f..c4767e8ee 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -14,26 +14,26 @@ import math import os from jax import Array -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from maxdiffusion.models.ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler from diffusers import AutoencoderKL from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor from transformers import ( - FlaxT5EncoderModel, - AutoTokenizer, + FlaxT5EncoderModel, AutoModelForCausalLM, AutoProcessor, - 
AutoTokenizer,) + AutoTokenizer, +) import json import numpy as np import torch from huggingface_hub import hf_hub_download -from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( +from maxdiffusion.models.ltx_video.models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( +from maxdiffusion.models.ltx_video.models.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, vae_decode, @@ -42,935 +42,739 @@ normalize_latents, ) from diffusers.image_processor import VaeImageProcessor -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt -from math import e from types import NoneType from typing import Any, Dict -import numpy as np import jax import jax.numpy as jnp -from jax.sharding import Mesh, PartitionSpec as P -from typing import Optional, Union, List +from jax.sharding import Mesh from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier from ...pyconfig import HyperParameters from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState -from ...max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations -) +from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import json import functools import orbax.checkpoint as ocp + def prepare_extra_step_kwargs(generator): - extra_step_kwargs = {} - extra_step_kwargs["generator"] = generator - return extra_step_kwargs + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs class LTXVideoPipeline: - def __init__( - self, - transformer: Transformer3DModel, - scheduler: FlaxRectifiedFlowMultistepScheduler, - scheduler_state: RectifiedFlowSchedulerState, - vae: AutoencoderKL, - text_encoder, - patchifier, - tokenizer, + + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: AutoencoderKL, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) + self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True + ) + ##load in jax weights checkpoint + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + return vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return ( prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - transformer_state: Dict[Any, Any] = None, - transformer_state_shardings: Dict[Any, Any] = NoneType, + ) + + @classmethod + def 
from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) + + # load from pytorch version + models_dir = config.models_dir + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + vae = cls.load_vae(ltxv_model_path) + vae = vae.to(torch.bfloat16) + text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = cls.load_prompt_enhancement(config) + else: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = (None, None, None, None) + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings, + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( # currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps( + state=scheduler_state, samples_shape=latent_shape, 
num_inference_steps=num_inference_steps + ) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps >= num_inference_steps ): - self.transformer = transformer - self.devices_array = devices_array - self.mesh = mesh - self.config = config - self.p_run_inference = None - self.transformer_state = transformer_state - self.transformer_state_shardings = transformer_state_shardings - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.vae = vae - self.text_encoder = text_encoder - self.patchifier = patchifier - self.tokenizer = tokenizer - self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model - self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor - self.prompt_enhancer_llm_model = prompt_enhancer_llm_model - self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer - self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( - self.vae - ) - self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor) - - @classmethod - def load_scheduler(cls, ckpt_path, config): - if config.sampler == "from_checkpoint" or not config.sampler: - scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) - else: - scheduler = FlaxRectifiedFlowMultistepScheduler( - sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") - ) - scheduler_state = scheduler.create_state() - - return scheduler, scheduler_state - - @classmethod - def load_transformer(cls, config): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", - "causal_temporal_positioning", "in_channels", "ckpt_path"] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) - - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True - ) - ##load in jax weights checkpoint - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - transformer_state = jax.device_put( - transformer_state, transformer_state_shardings) - get_memory_allocations() - - return transformer, transformer_state, transformer_state_shardings - - @classmethod - def load_vae(cls, ckpt_path): - vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) - return vae - - @classmethod - def load_text_encoder(cls, ckpt_path): - t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) - return t5_encoder - - @classmethod - def load_tokenizer(cls, config, ckpt_path): - t5_tokenizer = 
AutoTokenizer.from_pretrained( - ckpt_path, max_length=config.max_sequence_length, use_fast=True - ) - return t5_tokenizer - - @classmethod - def load_prompt_enhancement(cls, config): - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", - ) - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, - ) - return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer - - @classmethod - def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - transformer, transformer_state, transformer_state_shardings = cls.load_transformer( - config) - - # load from pytorch version - models_dir = config.models_dir - ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" - if not os.path.isfile(ltxv_model_name_or_path): - ltxv_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=ltxv_model_name_or_path, - local_dir=models_dir, - repo_type="model", - ) - else: - ltxv_model_path = ltxv_model_name_or_path - - scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) - vae = cls.load_vae(ltxv_model_path) - vae = vae.to(torch.bfloat16) - text_encoder = cls.load_text_encoder( - config.text_encoder_model_name_or_path) - patchifier = SymmetricPatchifier(patch_size=1) - tokenizer = cls.load_tokenizer( - config, config.text_encoder_model_name_or_path) - - enhance_prompt = False - if enhance_prompt: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( - config) - else: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None - - return LTXVideoPipeline( - transformer=transformer, - scheduler=scheduler, - scheduler_state=scheduler_state, - vae=vae, - text_encoder=text_encoder, - patchifier=patchifier, - tokenizer=tokenizer, - prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model=prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, - devices_array=devices_array, - mesh=mesh, - config=config, - transformer_state=transformer_state, - transformer_state_shardings=transformer_state_shardings - ) - - @classmethod - def _text_preprocessing(self, text): - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - text = text.strip() - return text - - return [process(t) for t in text] - - def denoising_step( - scheduler, - latents: Array, - noise_pred: Array, - current_timestep: Optional[Array], - conditioning_mask: Optional[Array], - t: float, - extra_step_kwargs: Dict, - t_eps: float = 1e-6, - stochastic_sampling: bool = False, - ) -> Array: - # Denoise the latents using the scheduler - denoised_latents = scheduler.step( - 
noise_pred, - t if current_timestep is None else current_timestep, - latents, - **extra_step_kwargs, - stochastic_sampling=stochastic_sampling, - ) - - if conditioning_mask is None: - return denoised_latents - - tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) - tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) - return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) - - def retrieve_timesteps( #currently doesn't support custom timesteps - self, - scheduler: FlaxRectifiedFlowMultistepScheduler, + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = text_encoder_max_tokens + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) #noqa: F841 + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype #noqa: F841 + elif self.transformer is not None: + dtype = self.transformer.dtype #noqa: F841 + else: + dtype = None #noqa: F841 + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) + prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = 
prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, (1, num_images_per_prompt)) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( ## this is in pytorch + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." + ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items, + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." 
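+    # Note on the step below: when caller-provided latents pass the shape check above,
+    # they are not used as-is. They are blended with fresh noise via linear interpolation,
+    # latents = timestep * noise + (1 - timestep) * latents, which places the sample at the
+    # requested first rectified-flow timestep before denoising begins.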
+ + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor((b, f * h * w, c), generator=generator, device=device, dtype=dtype) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( # removed conditioning_item logic + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, CausalVideoAutoencoder) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True, + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True) + import pdb + + pdb.set_trace() + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = 
json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps( + self.scheduler, latent_shape, - scheduler_state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - timesteps: Optional[List[int]] = None, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - ): - scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) - timesteps = scheduler_state.timesteps - if ( - skip_initial_inference_steps < 0 - or skip_final_inference_steps < 0 - or skip_initial_inference_steps + skip_final_inference_steps - >= num_inference_steps - ): - raise ValueError( - "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" - ) - timesteps = timesteps[ - skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + self.scheduler_state, + num_inference_steps, + None, + skip_initial_inference_steps, + skip_final_inference_steps, + ) + + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep] + guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) + for skip_blocks in skip_block_list ] - scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) - - - return scheduler_state - - def encode_prompt( - self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = 
None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - text_encoder_max_tokens: int = 256, - **kwargs, - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - max_length = ( - text_encoder_max_tokens - ) - if prompt_embeds is None: - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," - - prompt = self._text_preprocessing(prompt) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, max_length - 1: -1] - ) - - prompt_attention_mask = jnp.array(text_inputs.attention_mask) - prompt_embeds = self.text_encoder( - text_input_ids, attention_mask=prompt_attention_mask - ) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype - else: - dtype = None - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - prompt_embeds = jnp.reshape( - prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) - prompt_attention_mask = jnp.tile( - prompt_attention_mask, (1, num_images_per_prompt)) - prompt_attention_mask = jnp.reshape( - prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = self._text_preprocessing(negative_prompt) - uncond_tokens = uncond_tokens * batch_size - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) - - negative_prompt_embeds = self.text_encoder( - jnp.array(uncond_input.input_ids), - attention_mask=negative_prompt_attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = jnp.tile(negative_prompt_embeds, - (1, num_images_per_prompt, 1) - ) - negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, - (batch_size * num_images_per_prompt, seq_len, -1) - ) - - negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, - (1, num_images_per_prompt) - ) - negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, - (bs_embed * num_images_per_prompt, -1) - ) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - - return ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) - - def prepare_latents( ## this is in pytorch - self, - latents: 
torch.Tensor | None, - media_items: torch.Tensor | None, - timestep: float, - latent_shape: torch.Size | Tuple[Any, ...], - dtype: torch.dtype, - device: torch.device, - generator: torch.Generator | List[torch.Generator], - vae_per_channel_normalize: bool = True, - ): - if isinstance(generator, list) and len(generator) != latent_shape[0]: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." - ) - - # Initialize the latents with the given latents or encoded media item, if provided - assert ( - latents is None or media_items is None - ), "Cannot provide both latents and media_items. Please provide only one of the two." - - assert ( - latents is None and media_items is None or timestep < 1.0 - ), "Input media_item or latents are provided, but they will be replaced with noise." - - if media_items is not None: - latents = vae_encode( - media_items, - self.vae, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - if latents is not None: - assert ( - latents.shape == latent_shape - ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." - - # For backward compatibility, generate in the "patchified" shape and rearrange - b, c, f, h, w = latent_shape - noise = randn_tensor( - (b, f * h * w, c), generator=generator, device=device, dtype=dtype - ) - noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) - - # scale the initial noise by the standard deviation required by the scheduler - # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have - - if latents is None: - latents = noise - else: - # Noise the latents to the required (first) timestep - timestep = torch.from_numpy(np.array(timestep)) - latents = timestep * noise + (1 - timestep) * latents - - return latents - - def prepare_conditioning( - self, - conditioning_items, - init_latents: torch.Tensor, - num_frames: int, - height: int, - width: int, - vae_per_channel_normalize: bool = True, - generator=None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - - assert isinstance(self.vae, CausalVideoAutoencoder) - - if conditioning_items: - batch_size, _, num_latent_frames = init_latents.shape[:3] - - init_conditioning_mask = torch.zeros( - init_latents[:, 0, :, :, :].shape, - dtype=torch.float32, - device=init_latents.device, - ) - - extra_conditioning_latents = [] - extra_conditioning_pixel_coords = [] - extra_conditioning_mask = [] - extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) - - # Process each conditioning item - for conditioning_item in conditioning_items: - conditioning_item = self._resize_conditioning_item( - conditioning_item, height, width - ) - media_item = conditioning_item.media_item - media_frame_number = conditioning_item.media_frame_number - strength = conditioning_item.conditioning_strength - assert media_item.ndim == 5 # (b, c, f, h, w) - b, c, n_frames, h, w = media_item.shape - assert ( - height == h and width == w - ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - assert n_frames % 8 == 1 - assert ( - media_frame_number >= 0 - and media_frame_number + n_frames <= num_frames - ) - - # Encode the provided conditioning media item - media_item_latents = vae_encode( - media_item.to(dtype=self.vae.dtype, device=self.vae.device), - self.vae, - 
vae_per_channel_normalize=vae_per_channel_normalize, - ).to(dtype=init_latents.dtype) - - # Handle the different conditioning cases - if media_frame_number == 0: - # Get the target spatial position of the latent conditioning item - media_item_latents, l_x, l_y = self._get_latent_spatial_position( - media_item_latents, - conditioning_item, - height, - width, - strip_latent_border=True, - ) - b, c_l, f_l, h_l, w_l = media_item_latents.shape - - # First frame or sequence - just update the initial noise latents and the mask - init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( - torch.lerp( - init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], - media_item_latents, - strength, - ) - ) - init_conditioning_mask[ - :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l - ] = strength - else: - # Non-first frame or sequence - if n_frames > 1: - # Handle non-first sequence. - # Encoded latents are either fully consumed, or the prefix is handled separately below. - ( - init_latents, - init_conditioning_mask, - media_item_latents, - ) = self._handle_non_first_conditioning_sequence( - init_latents, - init_conditioning_mask, - media_item_latents, - media_frame_number, - strength, - ) - - # Single frame or sequence-prefix latents - if media_item_latents is not None: - noise = randn_tensor( - media_item_latents.shape, - generator=generator, - device=media_item_latents.device, - dtype=media_item_latents.dtype, - ) - - media_item_latents = torch.lerp( - noise, media_item_latents, strength - ) - - # Patchify the extra conditioning latents and calculate their pixel coordinates - media_item_latents, latent_coords = self.patchifier.patchify( - latents=media_item_latents - ) - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.vae, - causal_fix=self.transformer.config.causal_temporal_positioning, - ) - - # Update the frame numbers to match the target frame number - pixel_coords[:, 0] += media_frame_number - extra_conditioning_num_latents += media_item_latents.shape[1] - - conditioning_mask = torch.full( - media_item_latents.shape[:2], - strength, - dtype=torch.float32, - device=init_latents.device, - ) - - extra_conditioning_latents.append(media_item_latents) - extra_conditioning_pixel_coords.append(pixel_coords) - extra_conditioning_mask.append(conditioning_mask) - - # Patchify the updated latents and calculate their pixel coordinates - init_latents, init_latent_coords = self.patchifier.patchify( - latents=init_latents - ) - init_pixel_coords = latent_to_pixel_coords( - init_latent_coords, - self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now - causal_fix=True - ) - - if not conditioning_items: - return init_latents, init_pixel_coords, None, 0 - - init_conditioning_mask, _ = self.patchifier.patchify( - latents=init_conditioning_mask.unsqueeze(1) - ) - init_conditioning_mask = init_conditioning_mask.squeeze(-1) - - if extra_conditioning_latents: - # Stack the extra conditioning latents, pixel coordinates and mask - init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) - init_pixel_coords = torch.cat( - [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 - ) - init_conditioning_mask = torch.cat( - [*extra_conditioning_mask, init_conditioning_mask], dim=1 - ) - - if self.transformer.use_tpu_flash_attention: - # When flash attention is used, keep the original number of tokens by removing - # tokens from the end. 
- init_latents = init_latents[:, :-extra_conditioning_num_latents] - init_pixel_coords = init_pixel_coords[ - :, :, :-extra_conditioning_num_latents - ] - init_conditioning_mask = init_conditioning_mask[ - :, :-extra_conditioning_num_latents - ] - - return ( - init_latents, - init_pixel_coords, - init_conditioning_mask, - extra_conditioning_num_latents, - ) - - - - def __call__( - self, - height: int, - width: int, - num_frames: int, - negative_prompt: str = "", - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - frame_rate: int = 30, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - guidance_timesteps: Optional[List[int]] = None, - decode_timestep: Union[List[float], float] = 0.05, - decode_noise_scale: Optional[List[float]] = 0.025, - offload_to_cpu: bool = False, - enhance_prompt: bool = False, - text_encoder_max_tokens: int = 256, - num_inference_steps: int = 50, - guidance_scale: Union[float, List[float]] = 4.5, - rescaling_scale: Union[float, List[float]] = 0.7, - stg_scale: Union[float, List[float]] = 1.0, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - cfg_star_rescale: bool = False, - skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, - **kwargs, - ): - enhance_prompt = False - prompt = self.config.prompt - is_video = kwargs.get("is_video", False) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - vae_per_channel_normalize = kwargs.get( - "vae_per_channel_normalize", True) - import pdb; pdb.set_trace() - - latent_height = height // self.vae_scale_factor - latent_width = width // self.vae_scale_factor - latent_num_frames = num_frames // self.video_scale_factor - if isinstance(self.vae, CausalVideoAutoencoder) and is_video: - latent_num_frames += 1 - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - - latent_shape = ( - batch_size * num_images_per_prompt, - model_config["in_channels"], - latent_num_frames, - latent_height, - latent_width, - ) - scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) - - - guidance_mapping = [] - - if guidance_timesteps: - for timestep in scheduler_state.timesteps: - indices = [ - i for i, val in enumerate(guidance_timesteps) if val <= timestep - ] - guidance_mapping.append( - indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) - ) - - if not isinstance(guidance_scale, list): - guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) - else: - guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(stg_scale, list): - stg_scale = [stg_scale] * len(scheduler_state.timesteps) - else: - stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not 
isinstance(rescaling_scale, list): - rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) - else: - rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) - do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) - do_rescaling = any(x != 1.0 for x in rescaling_scale) - - - num_conds = 1 - if do_classifier_free_guidance: - num_conds += 1 - if do_spatio_temporal_guidance: - num_conds += 1 - - is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) - - if not is_list_of_lists: - skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) - else: - new_skip_block_list = [] - for i in range(len(scheduler_state.timesteps)): - new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) - - skip_block_list = new_skip_block_list - - if do_spatio_temporal_guidance: - if skip_block_list is not None: - skip_layer_masks = [ - self.transformer.create_skip_layer_mask( - batch_size, num_conds, num_conds - 1, skip_blocks - ) - for skip_blocks in skip_block_list - ] - if enhance_prompt: - prompt = generate_cinematic_prompt( - self.prompt_enhancer_image_caption_model, - self.prompt_enhancer_image_caption_processor, - self.prompt_enhancer_llm_model, - self.prompt_enhancer_llm_tokenizer, - prompt, - None, # conditioning items set to None - max_new_tokens=text_encoder_max_tokens, - ) - - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - device=None, # device set to none - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - text_encoder_max_tokens=text_encoder_max_tokens, - ) - prompt_embeds_batch = prompt_embeds - prompt_attention_mask_batch = prompt_attention_mask - if do_classifier_free_guidance: - prompt_embeds_batch = jnp.concatenate( - [negative_prompt_embeds, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [negative_prompt_attention_mask, prompt_attention_mask], axis=0 - ) - if do_spatio_temporal_guidance: - prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [ - prompt_attention_mask_batch, - prompt_attention_mask, - ], - axis=0, - ) - latents = self.prepare_latents( - latents=latents, - media_items=None, # set to None - timestep=scheduler_state.timesteps[0], - latent_shape=latent_shape, - dtype=None, - device=None, - generator=generator, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - - latents, pixel_coords, conditioning_mask, num_cond_latents = ( - self.prepare_conditioning( - conditioning_items=None, - init_latents=latents, - num_frames=num_frames, - height=height, - width=width, - vae_per_channel_normalize=vae_per_channel_normalize, - generator=generator, - ) - ) - - extra_step_kwargs = prepare_extra_step_kwargs( - generator=jax.random.PRNGKey(0)) - - pixel_coords = torch.cat([pixel_coords] * num_conds) - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - - noise_cond = jnp.ones( # initialize first round with this! 
- (1, 1) - ) - segment_ids = None #how is this created? - p_run_inference = functools.partial( - run_inference, - transformer=self.transformer, - config=self.config, - mesh=self.mesh, - fractional_cords=jnp.array( - fractional_coords.to(torch.float32).detach().numpy()), - prompt_embeds=prompt_embeds_batch, - segment_ids=None, - encoder_attention_segment_ids=prompt_attention_mask_batch, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - do_classifier_free_guidance=do_classifier_free_guidance, - num_conds=num_conds, - guidance_scale=guidance_scale, - do_spatio_temporal_guidance=do_spatio_temporal_guidance, - stg_scale=stg_scale, - do_rescaling = do_rescaling, - rescaling_scale = rescaling_scale, - batch_size=batch_size, - skip_layer_masks = skip_layer_masks, - cfg_star_rescale = cfg_star_rescale - ) - - with self.mesh: - latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( - torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) - latents = torch.from_numpy(np.array(latents)) - latents = latents[:, num_cond_latents:] - - latents = self.patchifier.unpatchify( - latents=latents, - output_height=latent_height, - output_width=latent_width, - out_channels=model_config["in_channels"] - // math.prod(self.patchifier.patch_size), - ) - if output_type != "latent": - if self.vae.decoder.timestep_conditioning: - noise = torch.randn_like(latents) - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * latents.shape[0] - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [ - decode_noise_scale] * latents.shape[0] - - decode_timestep = torch.tensor( - decode_timestep).to(latents.device) - decode_noise_scale = torch.tensor(decode_noise_scale).to( - latents.device - )[:, None, None, None, None] - latents = ( - latents * (1 - decode_noise_scale) + - noise * decode_noise_scale - ) - else: - decode_timestep = None - image = vae_decode( - latents, - self.vae, - is_video, - vae_per_channel_normalize=kwargs.get( - "vae_per_channel_normalize", True), - timestep=decode_timestep, - ) - image = self.image_processor.postprocess( - image, output_type=output_type) - - else: - image = latents - - # Offload all models - - if not return_dict: - return (image,) - - return image - - - -def transformer_forward_pass( # need to jit this? 
wan didnt + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + + + pixel_coords = torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + noise_cond = jnp.ones((1, 1)) # initialize first round with this! 
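+    # run_inference is partially applied with everything that stays fixed across denoising
+    # steps (transformer, config, mesh, coordinates, prompt batch, guidance settings); only
+    # the transformer state, latents, timestep and scheduler state vary per call. The prompt
+    # embeddings were stacked along axis 0 as [negative, positive(, positive again for the
+    # perturbed STG branch)], matching num_conds, so the batched prediction can be split
+    # back apart inside the denoising loop.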
+ p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array(fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling=do_rescaling, + rescaling_scale=rescaling_scale, + batch_size=batch_size, + skip_layer_masks=skip_layer_masks, + cfg_star_rescale=cfg_star_rescale, + ) + + with self.mesh: + latents, scheduler_state = p_run_inference( + transformer_state=self.transformer_state, + latents=jnp.array(latents.to(torch.float32).detach().numpy()), + timestep=noise_cond, + scheduler_state=scheduler_state, + ) + latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[:, None, None, None, None] + latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + image = self.image_processor.postprocess(image, output_type=output_type) + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + + +def transformer_forward_pass( latents, state, noise_cond, @@ -981,220 +785,234 @@ def transformer_forward_pass( # need to jit this? 
wan didnt encoder_attention_segment_ids, skip_layer_mask, ): - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - skip_layer_mask=skip_layer_mask, - ) - return noise_pred, state + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + ) + return noise_pred, state def run_inference( - transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale -): - for i, t in enumerate(scheduler_state.timesteps): - current_timestep = t - latent_model_input = ( - jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents - ) - if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - is_mps = False - if isinstance(current_timestep, float): - dtype = jnp.float32 - else: - dtype = jnp.int32 - - current_timestep = jnp.array( - [current_timestep], - dtype=dtype, - ) - elif current_timestep.ndim == 0: - current_timestep = jnp.expand_dims(current_timestep, axis=0) - - # Broadcast to batch dimension - current_timestep = jnp.broadcast_to( - current_timestep, (latent_model_input.shape[0],1) - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( - skip_layer_masks[i] - if skip_layer_masks is not None - else None - )) - # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) - - - if do_spatio_temporal_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_text = chunks[-2] - noise_pred_text_perturb = chunks[-1] - - if do_classifier_free_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_uncond = chunks[0] - noise_pred_text = chunks[1] - if cfg_star_rescale: - positive_flat = noise_pred_text.reshape(batch_size, -1) - negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum( - positive_flat * negative_flat, axis=1, keepdims=True - ) - squared_norm = ( - jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 - ) - alpha = dot_product / squared_norm - alpha = alpha.reshape(batch_size, 1, 1) - - noise_pred_uncond = alpha * noise_pred_uncond - noise_pred = noise_pred_uncond + guidance_scale[i] * ( - noise_pred_text - noise_pred_uncond - ) - elif 
do_spatio_temporal_guidance: - noise_pred = noise_pred_text - - if do_spatio_temporal_guidance: - noise_pred = noise_pred + stg_scale[i] * ( - noise_pred_text - noise_pred_text_perturb - ) - if do_rescaling and stg_scale[i] > 0.0: - noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) - noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) - - factor = noise_pred_text_std / noise_pred_std - factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) - - - noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] - latents, scheduler_state = scheduler.step( - scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() - - return latents, scheduler_state - -def adain_filter_latent( - latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 + transformer_state, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + num_inference_steps, + scheduler, + segment_ids, + encoder_attention_segment_ids, + scheduler_state, + do_classifier_free_guidance, + num_conds, + guidance_scale, + do_spatio_temporal_guidance, + stg_scale, + do_rescaling, + rescaling_scale, + batch_size, + skip_layer_masks, + cfg_star_rescale, ): - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor. - - Args: - latent (torch.Tensor): Input latents to normalize - reference_latent (torch.Tensor): The reference latents providing style statistics. - factor (float): Blending factor between original and transformed latent. - Range: -10.0 to 10.0, Default: 1.0 - - Returns: - torch.Tensor: The transformed latent tensor - """ - result = latents.clone() - - for i in range(latents.size(0)): - for c in range(latents.size(1)): - r_sd, r_mean = torch.std_mean( - reference_latents[i, c], dim=None - ) # index by original dim order - i_sd, i_mean = torch.std_mean(result[i, c], dim=None) - - result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean - - result = torch.lerp(latents, result, factor) - return result + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension + current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, + transformer_state, + current_timestep, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), + ) + + if do_spatio_temporal_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = 
jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + + +def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + class LTXMultiScalePipeline: - - @classmethod - def load_latent_upsampler(cls, config): - spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile( - spatial_upscaler_model_name_or_path - ): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir= "/mnt/disks/diffusionproj", - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) - latent_upsampler.eval() - return latent_upsampler - - - def _upsample_latents( - self, latest_upsampler: LatentUpsampler, latents: torch.Tensor - ): - assert latents.device == latest_upsampler.device - - latents = un_normalize_latents( - latents, self.vae, vae_per_channel_normalize=True - ) - upsampled_latents = latest_upsampler(latents) - upsampled_latents = normalize_latents( - upsampled_latents, self.vae, vae_per_channel_normalize=True - ) - return upsampled_latents - - def __init__( - self, video_pipeline: LTXVideoPipeline - ): - self.video_pipeline = video_pipeline - self.vae = video_pipeline.vae - - def __call__( - self, - height, - width, - 
num_frames, - output_type, - generator, - config - ) -> Any: - - latent_upsampler = self.load_latent_upsampler(config) - original_output_type = output_type - output_type = 'latent' - result = self.video_pipeline(height=height, width=width, num_frames=num_frames, - is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], - num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_block_list=config.first_pass["skip_block_list"]) - latents = result - upsampled_latents = self._upsample_latents(latent_upsampler, latents) - upsampled_latents = adain_filter_latent( - latents=upsampled_latents, reference_latents=latents - ) - - latents = upsampled_latents - output_type = original_output_type - - - result = self.video_pipeline(height=height*2, width=width*2, num_frames=num_frames, - is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], - num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_block_list=config.second_pass["skip_block_list"]) - - if original_output_type != "latent": - num_frames = result.shape[2] - videos = rearrange(result, "b c f h w -> (b f) c h w") - - videos = F.interpolate( - videos, - size=(height, width), - mode="bilinear", - align_corners=False, - ) - videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) - result = videos - - return result \ No newline at end of file + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=config.models_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + + def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): + assert latents.device == latest_upsampler.device + + latents = un_normalize_latents(latents, self.vae, vae_per_channel_normalize=True) + upsampled_latents = latest_upsampler(latents) + upsampled_latents = normalize_latents(upsampled_latents, self.vae, vae_per_channel_normalize=True) + return upsampled_latents + + def __init__(self, video_pipeline: LTXVideoPipeline): + self.video_pipeline = 
video_pipeline + self.vae = video_pipeline.vae + + def __call__(self, height, width, num_frames, output_type, generator, config) -> Any: + + latent_upsampler = self.load_latent_upsampler(config) + original_output_type = output_type + output_type = "latent" + result = self.video_pipeline( + height=height, + width=width, + num_frames=num_frames, + is_video=True, + output_type=output_type, + generator=generator, + guidance_scale=config.first_pass["guidance_scale"], + stg_scale=config.first_pass["stg_scale"], + rescaling_scale=config.first_pass["rescaling_scale"], + skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.first_pass["skip_final_inference_steps"], + num_inference_steps=config.first_pass["num_inference_steps"], + guidance_timesteps=config.first_pass["guidance_timesteps"], + cfg_star_rescale=config.first_pass["cfg_star_rescale"], + skip_block_list=config.first_pass["skip_block_list"], + ) + latents = result + upsampled_latents = self._upsample_latents(latent_upsampler, latents) + upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) + + latents = upsampled_latents + output_type = original_output_type + + result = self.video_pipeline( + height=height * 2, + width=width * 2, + num_frames=num_frames, + is_video=True, + output_type=output_type, + latents=latents, + generator=generator, + guidance_scale=config.second_pass["guidance_scale"], + stg_scale=config.second_pass["stg_scale"], + rescaling_scale=config.second_pass["rescaling_scale"], + skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.second_pass["skip_final_inference_steps"], + num_inference_steps=config.second_pass["num_inference_steps"], + guidance_timesteps=config.second_pass["guidance_timesteps"], + cfg_star_rescale=config.second_pass["cfg_star_rescale"], + skip_block_list=config.second_pass["skip_block_list"], + ) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py index 1624b81c4..b550aeea3 100644 --- a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -24,42 +24,47 @@ import jax.numpy as jnp import json from maxdiffusion.configuration_utils import ConfigMixin, register_to_config -from maxdiffusion.utils import is_scipy_available from maxdiffusion.schedulers.scheduling_utils_flax import ( CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, ) -def linear_quadratic_schedule_jax(num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None) -> jnp.ndarray: - if num_steps == 1: - return jnp.array([1.0], dtype=jnp.float32) - if linear_steps is None: - linear_steps = num_steps // 2 - linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps +def linear_quadratic_schedule_jax( + num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None +) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + 
linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const - threshold_noise_step_diff = linear_steps - threshold_noise * num_steps - quadratic_steps = num_steps - linear_steps - quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) - linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) - const = quadratic_coef * (linear_steps**2) - quadratic_indices = jnp.arange(linear_steps, num_steps) - quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) - sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) - sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) - sigma_schedule = 1.0 - sigma_schedule - return sigma_schedule[:-1].astype(jnp.float32) def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: - mu_f = jnp.array(mu, dtype=jnp.float32) - sigma_f = jnp.array(sigma, dtype=jnp.float32) - return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) - + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + + def _prod_jax(iterable): - return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) - + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + + def get_normal_shift_jax( n_tokens: int, min_tokens: int = 1024, @@ -67,72 +72,71 @@ def get_normal_shift_jax( min_shift: float = 0.95, max_shift: float = 2.05, ) -> float: - m = (max_shift - min_shift) / (max_tokens - min_tokens) - b = min_shift - m * min_tokens - return m * n_tokens + b -def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray: - """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" - return x[(...,) + (None,) * (target_dims - x.ndim)] + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b +def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray: + """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" + return x[(...,) + (None,) * (target_dims - x.ndim)] + def strech_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray: - if shifts.size == 0: - raise ValueError("The 'shifts' tensor must not be empty.") - if terminal <= 0 or terminal >= 1: - raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + if shifts.size == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + one_minus_z = 1.0 - shifts + # Using shifts[-1] for the last 
element + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched_shifts = 1.0 - (one_minus_z / scale_factor) - one_minus_z = 1.0 - shifts - # Using shifts[-1] for the last element - scale_factor = one_minus_z[-1] / (1.0 - terminal) - stretched_shifts = 1.0 - (one_minus_z / scale_factor) + return stretched_shifts - return stretched_shifts def sd3_resolution_dependent_timestep_shift_jax( samples_shape: Tuple[int, ...], timesteps: jnp.ndarray, target_shift_terminal: Optional[float] = None, ) -> jnp.ndarray: - if len(samples_shape) == 3: - _, m, _ = samples_shape - elif len(samples_shape) in [4, 5]: - m = _prod_jax(samples_shape[2:]) - else: - raise ValueError( - "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" - ) + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + + shift = get_normal_shift_jax(int(m)) + time_shifts = time_shift_jax(shift, 1.0, timesteps) - shift = get_normal_shift_jax(int(m)) - time_shifts = time_shift_jax(shift, 1.0, timesteps) + if target_shift_terminal is not None: + time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) + return time_shifts - if target_shift_terminal is not None: - time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) - return time_shifts - def simple_diffusion_resolution_dependent_timestep_shift_jax( samples_shape: Tuple[int, ...], timesteps: jnp.ndarray, n: int = 32 * 32, ) -> jnp.ndarray: - if len(samples_shape) == 3: - _, m, _ = samples_shape - elif len(samples_shape) in [4, 5]: - m = _prod_jax(samples_shape[2:]) - else: - raise ValueError( - "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" - ) - # Ensure m and n are float32 for calculations - m_f = jnp.array(m, dtype=jnp.float32) - n_f = jnp.array(n, dtype=jnp.float32) + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + # Ensure m and n are float32 for calculations + m_f = jnp.array(m, dtype=jnp.float32) + n_f = jnp.array(n, dtype=jnp.float32) + + snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability + shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability + shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) - snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability - shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability - shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) + return shifted_timesteps - return shifted_timesteps @flax.struct.dataclass class RectifiedFlowSchedulerState: @@ -145,28 +149,21 @@ class RectifiedFlowSchedulerState: num_inference_steps: Optional[int] = None timesteps: Optional[jnp.ndarray] = None sigmas: Optional[jnp.ndarray] = None - - - @classmethod - def create( #need to change this! 
- cls, - common_state: CommonSchedulerState, - init_noise_sigma: float - ): + def create(cls, common_state: CommonSchedulerState, init_noise_sigma: float): return cls( - common = common_state, - init_noise_sigma = init_noise_sigma, - num_inference_steps = None, - timesteps = None, - sigmas = None, + common=common_state, + init_noise_sigma=init_noise_sigma, + num_inference_steps=None, + timesteps=None, + sigmas=None, ) @dataclass class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): - state: RectifiedFlowSchedulerState + state: RectifiedFlowSchedulerState class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): @@ -195,163 +192,136 @@ def __init__( dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype - def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: if common is None: common = CommonSchedulerState.create(self) init_noise_sigma = 1.0 - return RectifiedFlowSchedulerState.create(common_state = common, init_noise_sigma=init_noise_sigma) - - - def get_initial_timesteps_jax( - self, num_timesteps: int, shift: Optional[float] = None - ) -> jnp.ndarray: - if self.config.sampler == "Uniform": - return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) - elif self.config.sampler == "LinearQuadratic": - return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) - elif self.config.sampler == "Constant": - assert shift is not None, "Shift must be provided for constant time shift sampler." - return time_shift_jax( - shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) - ).astype(self.dtype) - else: - # This should be caught by __init__ but for safety - raise ValueError(f"Sampler {self.config.sampler} is not supported.") + return RectifiedFlowSchedulerState.create(common_state=common, init_noise_sigma=init_noise_sigma) + + def get_initial_timesteps_jax(self, num_timesteps: int, shift: Optional[float] = None) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler." 
+ return time_shift_jax(shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype)).astype( + self.dtype + ) + else: + raise ValueError(f"Sampler {self.config.sampler} is not supported.") def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray: - if self.config.shifting == "SD3": - return sd3_resolution_dependent_timestep_shift_jax( - samples_shape, timesteps, self.config.target_shift_terminal - ) - elif self.config.shifting == "SimpleDiffusion": - return simple_diffusion_resolution_dependent_timestep_shift_jax( - samples_shape, timesteps, self.config.base_resolution - ) - return timesteps - + if self.config.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.target_shift_terminal) + elif self.config.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.base_resolution) + return timesteps + def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]): pretrained_model_path = Path(pretrained_model_path) config = None if pretrained_model_path.is_file(): - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - configs = json.loads(metadata['config']) - config = configs["scheduler"] - + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + elif pretrained_model_path.is_dir(): - diffusers_noise_scheduler_config_path = ( - pretrained_model_path / "scheduler" / "scheduler_config.json" - ) - - if not diffusers_noise_scheduler_config_path.is_file(): - raise FileNotFoundError( - f"Scheduler config not found at {diffusers_noise_scheduler_config_path}" - ) - - with open(diffusers_noise_scheduler_config_path, "r") as f: - scheduler_config = json.load(f) - config = scheduler_config + diffusers_noise_scheduler_config_path = pretrained_model_path / "scheduler" / "scheduler_config.json" + + if not diffusers_noise_scheduler_config_path.is_file(): + raise FileNotFoundError(f"Scheduler config not found at {diffusers_noise_scheduler_config_path}") + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + config = scheduler_config return FlaxRectifiedFlowMultistepScheduler.from_config(config) - + def set_timesteps( - self, - state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - samples_shape: Optional[Tuple[int, ...]] = None, - timesteps: Optional[jnp.ndarray] = None, - device: Optional[str] = None, - ) -> RectifiedFlowSchedulerState: - if timesteps is not None and num_inference_steps is not None: - raise ValueError( - "You cannot provide both `timesteps` and `num_inference_steps`." 
- ) - - # Determine the number of inference steps if not provided - if num_inference_steps is None and timesteps is None: - raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") - - if timesteps is None: - num_inference_steps = jnp.minimum( - self.config.num_train_timesteps, num_inference_steps - ) - timesteps = self.get_initial_timesteps_jax( - num_inference_steps, shift=self.config.shift - ).astype(self.dtype) - - # Apply shifting if samples_shape is provided and shifting is configured - if samples_shape is not None: - timesteps = self.shift_timesteps_jax(samples_shape, timesteps) - else: - timesteps = jnp.asarray(timesteps, dtype=self.dtype) - num_inference_steps = len(timesteps) - - return state.replace( - timesteps=timesteps, - num_inference_steps=num_inference_steps, - sigmas=timesteps, # sigmas are the same as timesteps in RF - ) - + self, + state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[Tuple[int, ...]] = None, + timesteps: Optional[jnp.ndarray] = None, + device: Optional[str] = None, + ) -> RectifiedFlowSchedulerState: + if timesteps is not None and num_inference_steps is not None: + raise ValueError("You cannot provide both `timesteps` and `num_inference_steps`.") + + # Determine the number of inference steps if not provided + if num_inference_steps is None and timesteps is None: + raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") + + if timesteps is None: + num_inference_steps = jnp.minimum(self.config.num_train_timesteps, num_inference_steps) + timesteps = self.get_initial_timesteps_jax(num_inference_steps, shift=self.config.shift).astype(self.dtype) + + # Apply shifting if samples_shape is provided and shifting is configured + if samples_shape is not None: + timesteps = self.shift_timesteps_jax(samples_shape, timesteps) + else: + timesteps = jnp.asarray(timesteps, dtype=self.dtype) + num_inference_steps = len(timesteps) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + sigmas=timesteps, # sigmas are the same as timesteps in RF + ) + def scale_model_input( - self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - # Rectified Flow scheduler typically doesn't scale model input, returns as is. - return sample - - + self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + # Rectified Flow scheduler typically doesn't scale model input, returns as is. + return sample + def step( - self, - state: RectifiedFlowSchedulerState, - model_output: jnp.ndarray, - timestep: jnp.ndarray, # Can be global or per-token, but for RF it's typically global. 
- sample: jnp.ndarray, - return_dict: bool = True, - stochastic_sampling: bool = False, - generator: Optional[jax.random.PRNGKey] = None, - ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - t_eps = 1e-6 # Small epsilon for numerical issues - - timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) - - - if timestep.ndim == 0: - idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side='right') - current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] - lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), - timesteps_padded[current_t_idx + 1], - 0.0) - dt = timestep - lower_timestep - else: - current_t_indices = jnp.searchsorted(state.timesteps, timestep, side='right') # timesteps is decreasing - current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search - lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) - lower_timestep = timesteps_padded[lower_timestep_indices] - dt = timestep - lower_timestep - dt = append_dims_jax(dt, sample.ndim) - - - # Compute previous sample - if stochastic_sampling: - if generator is None: - raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") - broadcastable_timestep = append_dims_jax(timestep, sample.ndim) - - x0 = sample - broadcastable_timestep * model_output - next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep - - noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) - prev_sample = self.add_noise(state.common, x0, noise, next_timestep) - else: - prev_sample = sample - dt * model_output - - - if not return_dict: - return (prev_sample, state) - - return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) + self, + state: RectifiedFlowSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, + sample: jnp.ndarray, + return_dict: bool = True, + stochastic_sampling: bool = False, + generator: Optional[jax.random.PRNGKey] = None, + ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + if state.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + t_eps = 1e-6 # Small epsilon for numerical issues + + timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) + + if timestep.ndim == 0: + idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side="right") #noqa: F841 + current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] + lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), timesteps_padded[current_t_idx + 1], 0.0) + dt = timestep - lower_timestep + else: + current_t_indices = jnp.searchsorted(state.timesteps, timestep, side="right") # timesteps is decreasing + current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search + lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) + lower_timestep = timesteps_padded[lower_timestep_indices] + dt = timestep - lower_timestep + dt = append_dims_jax(dt, sample.ndim) + 
+ # Compute previous sample + if stochastic_sampling: + if generator is None: + raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") + broadcastable_timestep = append_dims_jax(timestep, sample.ndim) + + x0 = sample - broadcastable_timestep * model_output + next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep + + noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) + prev_sample = self.add_noise(state.common, x0, noise, next_timestep) + else: + prev_sample = sample - dt * model_output + + if not return_dict: + return (prev_sample, state) + + return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) From f63a6fab9ce6c25ef9dd46ed441f02bb62608de7 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 18:59:14 +0000 Subject: [PATCH 46/69] remove init --- src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py deleted file mode 100644 index e69de29bb..000000000 From b3874f565884f111af819a6ccd1747e11239483f Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 19:03:31 +0000 Subject: [PATCH 47/69] new empty folders --- .../ltx_video/models/autoencoders/__init__.py | 16 ++++++++++++++++ .../models/ltx_video/utils/__init__.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/__init__.py b/src/maxdiffusion/models/ltx_video/utils/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file From 3e6499c039971fc967d06f20f31679f29a30573a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sun, 20 Jul 2025 23:26:57 +0000 Subject: [PATCH 48/69] downloaded files --- .../ltx_video/autoencoders/causal_conv3d.py | 63 + .../autoencoders/causal_video_autoencoder.py | 1401 +++++++++++++++++ .../ltx_video/autoencoders/conv_nd_factory.py | 90 ++ .../ltx_video/autoencoders/dual_conv3d.py | 217 +++ .../autoencoders/latent_upsampler.py | 203 +++ .../ltx_video/autoencoders/pixel_norm.py | 12 + .../ltx_video/autoencoders/pixel_shuffle.py | 33 + .../models/ltx_video/autoencoders/vae.py | 380 +++++ .../ltx_video/autoencoders/vae_encode.py | 247 +++ .../ltx_video/autoencoders/vae_torchax.py | 88 ++ .../autoencoders/video_autoencoder.py | 1045 ++++++++++++ .../ltx_video/models/autoencoders/__init__.py | 16 - .../transformers/symmetric_patchifier.py | 84 + .../utils/diffusers_config_mapping.py | 174 ++ .../ltx_video/utils/prompt_enhance_utils.py | 226 +++ .../ltx_video/utils/skip_layer_strategy.py | 8 + .../models/ltx_video/utils/torch_utils.py | 25 + .../pipelines/ltx_video/ltx_video_pipeline.py | 36 +- 18 files changed, 4317 insertions(+), 31 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py delete mode 100644 src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_utils.py diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py new file mode 100644 index 000000000..98249c2f5 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -0,0 +1,63 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class 
CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py new file mode 100644 index 000000000..544c759c9 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1401 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path + +import torch +import numpy as np +from einops import rearrange +from torch import nn +from diffusers.utils import logging +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open + + +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper +from maxdiffusion.models.ltx_video.transformers.attention import Attention +from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) + +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." 
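
For intuition, a minimal sketch of how the CausalConv3d module defined just above (causal_conv3d.py) behaves: with causal=True it replicates the first frame time_kernel_size - 1 times before the 3D convolution, so every output frame depends only on the current and earlier frames, while height and width use ordinary symmetric padding. This is a hedged illustration only, assuming the maxdiffusion package from this patch is importable; the tensor sizes are arbitrary.

import torch
from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d

# Illustrative only: a 3x3x3 convolution over an 8-frame clip.
conv = CausalConv3d(in_channels=3, out_channels=16, kernel_size=3, stride=1)

x = torch.randn(1, 3, 8, 32, 32)  # (batch, channels, frames, height, width)

# causal=True: the first frame is repeated (time_kernel_size - 1) times in front,
# so no output frame sees future frames and the frame count is preserved.
y = conv(x, causal=True)
print(y.shape)  # torch.Size([1, 16, 8, 32, 32])

# causal=False: symmetric temporal padding (first and last frame each repeated once).
y_non_causal = conv(x, causal=False)
print(y_non_causal.shape)  # torch.Size([1, 16, 8, 32, 32])
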
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CausalVideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if ( + pretrained_model_name_or_path.is_dir() + and (pretrained_model_name_or_path / "autoencoder.pth").exists() + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( + std_of_means + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( + mean_of_means + ) + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "CausalVideoAutoencoder" + ), "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError( + f"latent_log_var={latent_log_var} requires use_quant_conv=False" + ) + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels + // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + 
patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_space", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } + ckpt_state_dict = { + key: value + for key, value in state_dict.items() + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + + model_keys = set(name for name, _ in self.named_modules()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = 
block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample).to(torch.bfloat16) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. 
+ in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name == "compress_all": + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( 
+ dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0 + ) + self.last_scale_shift_table = nn.Parameter( + torch.randn(2, output_channel) / output_channel**0.5 + ) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + sample = self.conv_norm_out(sample).to(torch.bfloat16) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError( + "attention_head_dim must be less than or equal to in_channels" + ) + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet( + hidden_states, 
causal=causal, timestep=timestep_embed + ) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, frames * height * width + ).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad( + hidden_states, (0, 0, 0, pad_len), "constant", 0 + ) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=( + None if not attention.use_tpu_flash_attention else mask + ), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, frames, height, width + ) + else: + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = ( + np.prod(stride) * in_channels // out_channels_reduction_factor + ) + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // 
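# --- Editorial sketch (not part of this patch): the padding step used before
# TPU flash attention above, in isolation. The token count is padded to a
# multiple of the kernel block size and a matching mask marks real tokens.
# block_k_major = 512 mirrors the value hard-coded in the mid-block.
import torch
import torch.nn.functional as F

block_k_major = 512
hidden_states = torch.randn(2, 700, 64)                     # (B, seq_len, C)
seq_len = hidden_states.shape[1]
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major   # 324 here

padded = F.pad(hidden_states, (0, 0, 0, pad_len))           # pad the sequence dimension
mask = torch.ones(padded.shape[0], seq_len)
mask = F.pad(mask, (0, pad_len))                            # zeros over the padded tokens
print(padded.shape, mask.shape)                             # (2, 1024, 64) (2, 1024)

unpadded = padded[:, :-pad_len, :] if pad_len > 0 else padded
print(unpadded.shape)                                       # (2, 700, 64)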
self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, 
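# --- Editorial sketch (not part of this patch): the channels-last LayerNorm
# pattern defined above. The channel axis is moved to the end so nn.LayerNorm
# normalizes over channels, then moved back; einops is assumed available, as
# in the patch itself.
import torch
import torch.nn as nn
from einops import rearrange

channels = 16
norm = nn.LayerNorm(channels, eps=1e-6)
x = torch.randn(2, channels, 4, 8, 8)                       # (B, C, D, H, W)

y = rearrange(x, "b c d h w -> b d h w c")
y = norm(y)
y = rearrange(y, "b d h w c -> b c d h w")
print(y.shape)                                              # torch.Size([2, 16, 4, 8, 8])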
per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states).to(torch.bfloat16) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + + hidden_states = self.norm2(hidden_states).to(torch.bfloat16) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + + input_tensor = self.norm3(input_tensor).to(torch.bfloat16) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_demo_config( + latent_channels: int = 64, +): + encoder_blocks = [ + ("res_x", {"num_layers": 2}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + 
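# --- Editorial sketch (not part of this patch): the StyleGAN-style spatial
# noise injection used by ResnetBlock3D above. A single 2-D noise map, scaled
# per channel, is broadcast over batch and frames. The scale is zero-initialized
# in the patch, so the injection is a no-op until trained.
import torch

b, c, f, h, w = 2, 8, 3, 16, 16
hidden_states = torch.randn(b, c, f, h, w)
per_channel_scale = torch.zeros(c, 1, 1)

spatial_noise = torch.randn(h, w)[None]                                # (1, H, W)
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]  # (1, C, 1, H, W)
out = hidden_states + scaled_noise
print(out.shape)                                            # torch.Size([2, 8, 3, 16, 16])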
("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ] + return { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "spatial_padding_mode": "replicate", + } + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_demo_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) + + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + + timestep = torch.ones(input_videos.shape[0]) * 0.1 + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape, timestep=timestep + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode( + image_latent, target_shape=image_latent.shape, timestep=timestep + ).sample + + first_frame_latent = latent[:, :, :1, :, :] + + assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py new file mode 100644 index 000000000..718c69bef --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py @@ -0,0 +1,90 @@ +from typing import Tuple, Union + +import torch + +from 
ltx_video.models.autoencoders.dual_conv3d import DualConv3d +from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", +): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py new file mode 100644 index 000000000..dcf889296 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -0,0 +1,217 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." 
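# --- Editorial usage sketch (not part of this patch) for the make_conv_nd
# factory above. The import path follows the patch's own ltx_video-style
# imports; adjust it to the maxdiffusion package layout if needed. The causal
# variant relies on the CausalConv3d module added elsewhere in this patch.
import torch
# from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd  # assumed path

conv2d = make_conv_nd(dims=2, in_channels=3, out_channels=8, kernel_size=3, padding=1)
conv3d = make_conv_nd(dims=3, in_channels=3, out_channels=8, kernel_size=3, padding=1)
dual = make_conv_nd(dims=(2, 1), in_channels=3, out_channels=8, kernel_size=3, padding=1)
causal = make_conv_nd(dims=3, in_channels=3, out_channels=8, kernel_size=3, causal=True)

x = torch.randn(1, 3, 8, 32, 32)                            # (B, C, F, H, W)
print(conv3d(x).shape)                                      # torch.Size([1, 8, 8, 32, 32])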
+ ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and 
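# --- Editorial sketch (not part of this patch): the (2+1)-D factorization
# DualConv3d implements, written directly with torch.nn.functional.conv3d —
# a spatial (1, k, k) convolution followed by a temporal (k, 1, 1) convolution.
# Weights are random stand-ins for weight1/weight2.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10, 10)                           # (B, C, F, H, W)
w_spatial = torch.randn(5, 3, 1, 3, 3)                      # like weight1: (C_mid, C_in, 1, kh, kw)
w_temporal = torch.randn(5, 5, 3, 1, 1)                     # like weight2: (C_out, C_mid, kt, 1, 1)

y = F.conv3d(x, w_spatial, stride=(1, 2, 2), padding=(0, 1, 1))
y = F.conv3d(y, w_temporal, stride=(2, 1, 1), padding=(1, 0, 0))
print(y.shape)                                              # torch.Size([1, 5, 5, 5, 5])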
dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py new file mode 100644 index 000000000..4a76bc21d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py @@ -0,0 +1,203 @@ +from typing import Optional, Union +from pathlib import Path +import os +import json + +import torch +import torch.nn as nn +from einops import rearrange +from diffusers import ConfigMixin, ModelMixin +from safetensors.torch import safe_open + +from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
+ + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + 
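# --- Editorial sketch (not part of this patch): the spatial-upsample branch
# above in miniature. A conv that multiplies the channel count by 4 followed
# by a 2x2 pixel shuffle doubles H and W. torch.nn.PixelShuffle is used here;
# for equal factors of 2, the patch's PixelShuffleND(2) applies the same
# channel-to-space rearrangement.
import torch
import torch.nn as nn

mid = 32
upsampler = nn.Sequential(
    nn.Conv2d(mid, 4 * mid, kernel_size=3, padding=1),
    nn.PixelShuffle(2),
)
x = torch.randn(1, mid, 16, 16)
print(upsampler(x).shape)                                   # torch.Size([1, 32, 32, 32])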
"spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler + + +if __name__ == "__main__": + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py new file mode 100644 index 000000000..9bc3ea60e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py new file mode 100644 index 000000000..4e79ae284 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from einops import rearrange + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py new file mode 100644 index 000000000..5b22217c1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -0,0 +1,380 @@ +from typing import Optional, Union + +import torch +import inspect +import math +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd + + +class AutoencoderKLWrapper(ModelMixin, ConfigMixin): + """Variational Autoencoder 
(VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + normalize_latent_channels: bool = False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + self.normalize_latent_channels = normalize_latent_channels + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd( + quant_dims, 2 * latent_channels, 2 * latent_channels, 1 + ) + self.post_quant_conv = make_conv_nd( + quant_dims, latent_channels, latent_channels, 1 + ) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + + if normalize_latent_channels: + if dims == 2: + self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert ( + z_sample_size % 8 == 0 or z_sample_size == 1 + ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False + + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
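# --- Editorial worked example (not part of this patch): the tiling parameters
# computed above, with illustrative assumptions of sample_size=512, four
# encoder down-blocks, and overlap_factor=0.25.
sample_size = 512
num_blocks = 4
overlap_factor = 0.25

tile_sample_min_size = sample_size                                   # 512-px tiles in pixel space
tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))    # 64 latent rows/cols per tile
overlap_size = int(tile_sample_min_size * (1 - overlap_factor))      # 384: stride between pixel tiles
blend_extent = int(tile_latent_min_size * overlap_factor)            # 16: latent rows blended between tiles
row_limit = tile_latent_min_size - blend_extent                      # 48: rows kept from each latent tile
print(tile_latent_min_size, overlap_size, blend_extent, row_limit)   # 64 384 16 48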
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( + 1 - z / blend_extent + ) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
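# --- Editorial sketch (not part of this patch): the linear cross-fade blend_h
# applies along the width dimension. The left tile's trailing columns fade out
# while the right tile's leading columns fade in, hiding the seam.
import torch

blend_extent = 4
a = torch.zeros(1, 1, 1, 1, 8)                              # left tile (all zeros)
b = torch.ones(1, 1, 1, 1, 8)                               # right tile (all ones)

for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) \
        + b[:, :, :, :, x] * (x / blend_extent)

print(b[0, 0, 0, 0])   # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])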
+ rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = ( + sizes + [z.shape[2] - sum(sizes)] + if z.shape[2] - sum(sizes) > 0 + else sizes + ) + tiles = z.split(sizes, dim=2) + moments_tiles = [ + ( + self._hw_tiled_encode(z_tile, return_dict) + if self.use_hw_tiling + else self._encode(z_tile) + ) + for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = ( + self._hw_tiled_encode(z, return_dict) + if self.use_hw_tiling + else self._encode(z) + ) + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + _, c, _, _, _ = z.shape + z = torch.cat( + [ + self.latent_norm_out(z[:, : c // 2, :, :, :]), + z[:, c // 2 :, :, :, :], + ], + dim=1, + ) + elif isinstance(self.latent_norm_out, nn.BatchNorm2d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) + running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) + eps = self.latent_norm_out.eps + + z = z * torch.sqrt(running_var + eps) + running_mean + elif isinstance(self.latent_norm_out, nn.BatchNorm3d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + moments = self._normalize_latent_channels(moments) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self._unnormalize_latent_channels(z) + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and 
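# --- Editorial sketch (not part of this patch): how the z-tiling path above
# splits the frame dimension into full chunks of z_sample_size frames plus one
# remainder chunk when the frame count is not an exact multiple.
import torch

z = torch.randn(1, 4, 21, 8, 8)                             # 21 latent frames
z_sample_size = 8

num_splits = z.shape[2] // z_sample_size                    # 2
sizes = [z_sample_size] * num_splits
remainder = z.shape[2] - sum(sizes)
if remainder > 0:
    sizes = sizes + [remainder]                             # [8, 8, 5]
tiles = z.split(sizes, dim=2)
print([t.shape[2] for t in tiles])                          # [8, 8, 5]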
z.shape[2] > self.z_sample_size > 1: + reduction_factor = int( + self.encoder.patch_size_t + * 2 + ** ( + len(self.encoder.down_blocks) + - 1 + - math.sqrt(self.encoder.patch_size) + ) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode(z, target_shape) + if self.use_hw_tiling + else self._decode(z, target_shape=target_shape, timestep=timestep) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py new file mode 100644 index 000000000..5a0aeeccf --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py @@ -0,0 +1,247 @@ +from typing import Tuple +import torch +from diffusers import AutoencoderKL +from einops import rearrange +from torch import Tensor + + +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import ( + Downsample3D, + VideoAutoencoder, +) + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def vae_encode( + media_items: Tensor, + vae: AutoencoderKL, + split_size: int = 1, + vae_per_channel_normalize=False, +) -> Tensor: + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. 
Defaults to 1, which processes all items in a single batch. + + Returns: + Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted + to match the input shape, scaled by the model's configuration. + + Examples: + >>> import torch + >>> from diffusers import AutoencoderKL + >>> vae = AutoencoderKL.from_pretrained('your-model-name') + >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. + >>> latents = vae_encode(images, vae) + >>> print(latents.shape) # Output shape will depend on the model's latent configuration. + + Note: + In case of a video, the function encodes the media item frame-by frame. + """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + latents = vae.encode(media_items).latent_dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool = True, + split_size: int = 1, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder( + latent_batch, vae, is_video, vae_per_channel_normalize, timestep + ) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder( + latents, vae, is_video, vae_per_channel_normalize, timestep + ) + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images + + +def _run_decoder( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + 
return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len( + [ + block + for block in vae.encoder.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + spatial = vae.config.patch_size * 2**down_blocks + temporal = ( + vae.config.patch_size_t * 2**down_blocks + if isinstance(vae, VideoAutoencoder) + else 1 + ) + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords( + latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors( + latent_coords, scale_factors, causal_fix + ) + return pixel_coords + + +def latent_to_pixel_coords_from_factors( + latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False +) -> Tensor: + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py new file mode 100644 index 000000000..45fb33280 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py @@ -0,0 +1,88 @@ +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder +from maxdiffusion.models.ltx_video.autoencoders import causal_conv3d +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import vae_encode, vae_decode + +import jax +from torchax import interop +import os +from torchax import default_env +import jax.numpy as jnp + +# remove weight attribute to avoid error 
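# --- Editorial sketch (not part of this patch): the per-channel latent
# normalization above (vae_per_channel_normalize=True) subtracts the
# mean-of-means and divides by the std-of-means per channel, and the inverse
# restores the original latents. Real statistics come from the VAE's
# registered buffers; random stand-ins are used here.
import torch

latents = torch.randn(2, 128, 3, 4, 4)
mean_of_means = torch.randn(128)
std_of_means = torch.rand(128) + 0.5

normalized = (latents - mean_of_means.view(1, -1, 1, 1, 1)) / std_of_means.view(1, -1, 1, 1, 1)
restored = normalized * std_of_means.view(1, -1, 1, 1, 1) + mean_of_means.view(1, -1, 1, 1, 1)
print(torch.allclose(restored, latents, atol=1e-5))         # True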
in JittableModule +# in the future, this will be fixed in ltxv public repo +delattr(causal_conv3d.CausalConv3d, 'weight') + +class TorchaxCausalVideoAutoencoder(interop.JittableModule): + def __init__(self, vae: CausalVideoAutoencoder): + super().__init__(vae, extra_jit_args=dict(static_argnames=['split_size', 'vae_per_channel_normalize'])) + + def encode(self, media_items: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True) -> jax.Array: + if media_items.ndim != 5: + raise ValueError( + f"Expected media_items to have 5 dimensions (batch, channels, frames, height, width), but got {media_items.ndim} dimensions." + ) + num_frames = media_items.shape[2] + if (num_frames - 1) % 8 != 0: + raise ValueError( + f"Expected media_items to have a number of frames that is 1 + 8 * k for some integer k, but got {num_frames} frames." + ) + with default_env(): + media_items = interop.torch_view(media_items) + + output = self.functional_call( + self._vae_encoder_inner, + params=self.params, + buffers=self.buffers, + media_items=media_items, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + + return interop.jax_view(output) + + def decode(self, latents: jax.Array, timestep: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True, is_video: bool = True) -> jax.Array: + with default_env(): + latents = interop.torch_view(latents) + timestep = interop.torch_view(timestep) + output = self.functional_call( + self._vae_decoder_inner, + params=self.params, + buffers=self.buffers, + latents=latents, + timestep=timestep, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + is_video=is_video, + ) + + return interop.jax_view(output) + + + @staticmethod + def _vae_encoder_inner(model, media_items, split_size, vae_per_channel_normalize): + return vae_encode( + media_items=media_items, + vae=model, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + @staticmethod + def _vae_decoder_inner(model, latents, timestep, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize: bool = False): + return vae_decode( + latents=latents, + vae=model, + is_video=is_video, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + timestep=timestep, + ) + + @staticmethod + def normalize_img(image): + return (image - 128) / 128 + + @staticmethod + def denormalize_img(image): + return (image * 128 + 128).clip(0, 255) \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py new file mode 100644 index 000000000..3c7926c1d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -0,0 +1,1045 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from ltx_video.utils.torch_utils import Identity +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ltx_video.models.autoencoders.pixel_norm import PixelNorm +from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: 
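# --- Editorial sketch (not part of this patch): the frame-count constraint the
# torchax wrapper's encode enforces above (frames must be 1 + 8*k), plus the
# arithmetic of its pixel normalization helpers, shown on plain numbers.
valid_frame_counts = [n for n in range(1, 30) if (n - 1) % 8 == 0]
print(valid_frame_counts)                                   # [1, 9, 17, 25]

pixel = 200.0
normalized = (pixel - 128) / 128                            # 0.5625, roughly in [-1, 1]
restored = min(max(normalized * 128 + 128, 0), 255)         # 200.0
print(normalized, restored)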
Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "VideoAutoencoder" + ), "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels + // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels + // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels + for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def 
is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + * self.patch_size + ) + + def forward( + self, sample: torch.FloatTensor, return_features=False + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)( + sample, downsample_in_time=downsample_in_time + ) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block + and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, block_out_channels[0], out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward( + self, hidden_states: torch.FloatTensor, downsample_in_time + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample( + hidden_states, downsample_in_time=downsample_in_time + ) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
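+        For 5D video inputs the frames dimension is carried through unchanged, since the
+        residual blocks keep channel count and resolution, so the output has shape
+        `(batch_size, in_channels, frames, height, width)`.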
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D( + dims=dims, channels=out_channels, out_channels=out_channels + ) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, upsample_in_time=True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
+ """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd( + dims, channels, out_channels, kernel_size=3, padding=1, bias=True + ) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + _, _, h, w = x.shape + + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) + + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange( + x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d + ) + # b c d h w + + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + + return self.conv(x) + + +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + 
"patch_size": 1, + } + + return config + + +def create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py deleted file mode 100644 index cb4a6b9ce..000000000 --- a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2025 Lightricks Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This implementation is based on the Torch version available at: -# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py new file mode 100644 index 000000000..2eca32033 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
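+        For example, with `latent_num_frames=2`, `latent_height=4`, `latent_width=4`
+        and a patch size of (1, 1, 1), the result has shape [batch_size, 3, 32],
+        where the 3 entries along dim 1 are the (frame, height, width) coordinates
+        of each patch, flattened in (f, h, w) order.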
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 000000000..53c0082d1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": 
"0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py new file mode 100644 index 
000000000..901051728 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -0,0 +1,226 @@ +import logging +from typing import Union, List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Do not change the user input intent, just enhance it. +Keep within 150 words. +For best results, build your prompts using this structure: +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + +I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Keep within 150 words. +For best results, build your prompts using this structure: +Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Align to the image caption if it contradicts the user text input. +Do not exceed the 150 word limit! +Output the enhanced prompt only. 
+""" + + +def tensor_to_pil(tensor): + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 + + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 + + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) + + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + + # Convert to PIL Image + return Image.fromarray(numpy_image) + + +def generate_cinematic_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompt: Union[str, List[str]], + conditioning_items: Optional[List] = None, + max_new_tokens: int = 256, +) -> List[str]: + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item( + first_frame_conditioning_item + ) + + assert len(first_frames) == len( + prompts + ), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts + + +def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: + frames_tensor = conditioning_item.media_item + return [ + tensor_to_pil(frames_tensor[i, :, 0, :, :]) + for i in range(frames_tensor.shape[0]) + ] + + +def _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + first_frames: List[Image.Image], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + image_captions = _generate_image_captions( + image_caption_model, image_caption_processor, first_frames + ) + + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, 
max_new_tokens + ) + + +def _generate_image_captions( + image_caption_model, + image_caption_processor, + images: List[Image.Image], + system_prompt: str = "", +) -> List[str]: + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor( + image_caption_prompts, images, return_tensors="pt" + ).to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int +) -> List[str]: + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate( + **model_inputs, max_new_tokens=max_new_tokens + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, outputs) + ] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + ) + + return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..30f9016e1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py new file mode 100644 index 000000000..991b07c36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -0,0 +1,25 @@ +import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index c4767e8ee..77922c9cf 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -14,8 +14,8 @@ import math import os from jax import Array -from maxdiffusion.models.ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler -from diffusers import AutoencoderKL +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from torchax import interop, default_env from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F @@ -30,10 +30,10 @@ import numpy as np import torch from huggingface_hub import hf_hub_download -from 
maxdiffusion.models.ltx_video.models.autoencoders.causal_video_autoencoder import ( +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from maxdiffusion.models.ltx_video.models.autoencoders.vae_encode import ( +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, vae_decode, @@ -51,6 +51,7 @@ from jax.sharding import Mesh from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier from ...pyconfig import HyperParameters +from maxdiffusion.models.ltx_video.autoencoders.vae_torchax import TorchaxCausalVideoAutoencoder from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel @@ -71,7 +72,7 @@ def __init__( transformer: Transformer3DModel, scheduler: FlaxRectifiedFlowMultistepScheduler, scheduler_state: RectifiedFlowSchedulerState, - vae: AutoencoderKL, + vae: TorchaxCausalVideoAutoencoder, text_encoder, patchifier, tokenizer, @@ -169,8 +170,11 @@ def load_transformer(cls, config): @classmethod def load_vae(cls, ckpt_path): - vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) - return vae + torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype = torch.bfloat16) + with default_env(): + torch_vae = torch_vae.to('jax') + jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) + return jax_vae @classmethod def load_text_encoder(cls, ckpt_path): @@ -550,9 +554,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True) - import pdb - - pdb.set_trace() latent_height = height // self.vae_scale_factor latent_width = width // self.vae_scale_factor @@ -943,11 +944,16 @@ def load_latent_upsampler(cls, config): return latent_upsampler def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): - assert latents.device == latest_upsampler.device - - latents = un_normalize_latents(latents, self.vae, vae_per_channel_normalize=True) - upsampled_latents = latest_upsampler(latents) - upsampled_latents = normalize_latents(upsampled_latents, self.vae, vae_per_channel_normalize=True) + latents = jax.device_put(latents, jax.devices('tpu')[0]) + #assert latents.device == latest_upsampler.device + with default_env(): + latents = un_normalize_latents( #need to switch this out? 
+ interop.torch_view(latents), self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) #here converted back to torch, cause upsampler in pytorch + upsampled_latents = normalize_latents( + interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True + ) return upsampled_latents def __init__(self, video_pipeline: LTXVideoPipeline): From 0b67a19a668b4d690cd5f1ca2076c2657d02fe03 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sun, 20 Jul 2025 23:45:11 +0000 Subject: [PATCH 49/69] changed upsampler --- .../pipelines/ltx_video/ltx_video_pipeline.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 77922c9cf..34195c8bb 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -762,7 +762,7 @@ def __call__( vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), timestep=decode_timestep, ) - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(torch.from_numpy(np.array(image.astype(jnp.float16))), output_type=output_type) else: image = latents @@ -983,9 +983,13 @@ def __call__(self, height, width, num_frames, output_type, generator, config) -> skip_block_list=config.first_pass["skip_block_list"], ) latents = result - upsampled_latents = self._upsample_latents(latent_upsampler, latents) - upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) - + upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here + + latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu')) + upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu')) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents + ) latents = upsampled_latents output_type = original_output_type From 443243d16408f5738d2ef40c22ec7999a65c6a97 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sun, 20 Jul 2025 23:53:44 +0000 Subject: [PATCH 50/69] kept latents as jnp --- src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 34195c8bb..f911fded7 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -731,7 +731,6 @@ def __call__( timestep=noise_cond, scheduler_state=scheduler_state, ) - latents = torch.from_numpy(np.array(latents)) latents = latents[:, num_cond_latents:] latents = self.patchifier.unpatchify( From fefe18ebe35f54842b5474ea25f3b12e34b470ee Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sun, 20 Jul 2025 23:57:54 +0000 Subject: [PATCH 51/69] prepare latents --- .../pipelines/ltx_video/ltx_video_pipeline.py | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index f911fded7..7e9e6ccc6 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ 
b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -892,31 +892,36 @@ def run_inference( return latents, scheduler_state -def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor. - - Args: - latent (torch.Tensor): Input latents to normalize - reference_latent (torch.Tensor): The reference latents providing style statistics. - factor (float): Blending factor between original and transformed latent. - Range: -10.0 to 10.0, Default: 1.0 - - Returns: - torch.Tensor: The transformed latent tensor - """ - result = latents.clone() - - for i in range(latents.size(0)): - for c in range(latents.size(1)): - r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order - i_sd, i_mean = torch.std_mean(result[i, c], dim=None) - - result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean - - result = torch.lerp(latents, result, factor) - return result +def adain_filter_latent( + latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 +): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + with default_env(): + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean( + reference_latents[i, c], dim=None + ) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result class LTXMultiScalePipeline: From 4bad19603bb95828e26756f47bbcdb5243499c71 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 21 Jul 2025 00:06:53 +0000 Subject: [PATCH 52/69] save --- .../pipelines/ltx_video/ltx_video_pipeline.py | 60 ++++++++++++------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 7e9e6ccc6..5cf945e59 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -740,31 +740,45 @@ def __call__( out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), ) if output_type != "latent": - if self.vae.decoder.timestep_conditioning: - noise = torch.randn_like(latents) - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * latents.shape[0] - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [decode_noise_scale] * latents.shape[0] - - decode_timestep = torch.tensor(decode_timestep).to(latents.device) - decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[:, None, None, None, None] - latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale - else: - decode_timestep = None - image = vae_decode( - latents, - self.vae, - is_video, - 
vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), - timestep=decode_timestep, - ) - image = self.image_processor.postprocess(torch.from_numpy(np.array(image.astype(jnp.float16))), output_type=output_type) + if self.vae.decoder.timestep_conditioning: + noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) #move the key to outer layer + + # Convert decode_timestep to a list if it's not already one + if not isinstance(decode_timestep, (list, jnp.ndarray)): + decode_timestep = [decode_timestep] * latents.shape[0] + + # Handle decode_noise_scale + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + # Convert lists to JAX arrays + decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) + + # Reshape decode_noise_scale for broadcasting + decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) + decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) + + # Apply the noise and scale + latents = ( + latents * (1 - decode_noise_scale) + + noise * decode_noise_scale + ) + else: + decode_timestep = None + image = self.vae.decode( + latents = jax.device_put(latents, jax.devices('tpu')[0]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]), + is_video = is_video, + vae_per_channel_normalize=kwargs.get( + "vae_per_channel_normalize", True), + timestep=decode_timestep #.astype(jnp.bfloat16), + ) + image = self.postprocess_to_output_type( #swap this out! + torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type) else: - image = latents + image = latents # Offload all models From b1e5b0c47b27e7d7b818d89c4bc5fa089f7d2114 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 21 Jul 2025 00:19:40 +0000 Subject: [PATCH 53/69] fixed transformer init --- .../pipelines/ltx_video/ltx_video_pipeline.py | 2247 ++++++++++------- 1 file changed, 1320 insertions(+), 927 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 5cf945e59..547445570 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -11,24 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import math import os from jax import Array -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler -from torchax import interop, default_env +from datetime import datetime +from pathlib import Path from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor +from maxdiffusion.models.ltx_video.autoencoders.vae_torchax import TorchaxCausalVideoAutoencoder +# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +import yaml +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, + T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) + +from torchax import interop +from torchax import default_env +import imageio +import json +import numpy as np +import torch +from safetensors import safe_open +from PIL import Image from transformers import ( - FlaxT5EncoderModel, + T5EncoderModel, + T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, ) -import json -import numpy as np -import torch from huggingface_hub import hf_hub_download from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, @@ -36,759 +49,1161 @@ from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, - vae_decode, - vae_encode, un_normalize_latents, normalize_latents, ) -from diffusers.image_processor import VaeImageProcessor +# from diffusers.image_processor import VaeImageProcessor +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from math import e from types import NoneType from typing import Any, Dict +import numpy as np +import inspect import jax import jax.numpy as jnp -from jax.sharding import Mesh +from jax.sharding import Mesh, PartitionSpec as P +from typing import Optional, Union, List +import torch +from maxdiffusion.checkpointing import checkpointing_utils +from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...pyconfig import HyperParameters -from maxdiffusion.models.ltx_video.autoencoders.vae_torchax import TorchaxCausalVideoAutoencoder +# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState -from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) +from ...max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations +) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import json import functools import orbax.checkpoint as ocp +import pickle -def prepare_extra_step_kwargs(generator): - extra_step_kwargs = {} - extra_step_kwargs["generator"] = generator - return extra_step_kwargs +class PickleCheckpointHandler(ocp.CheckpointHandler): + def save(self, directory: str, item, args=None): + os.makedirs(directory, exist_ok=True) + with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: + pickle.dump(item, f) + def restore(self, directory: str, args=None): + with open(os.path.join(directory, 
'checkpoint.pkl'), 'rb') as f: + return pickle.load(f) + + def structure(self, directory: str): + return {} # not needed for simple pickle-based handling -class LTXVideoPipeline: - def __init__( - self, - transformer: Transformer3DModel, - scheduler: FlaxRectifiedFlowMultistepScheduler, - scheduler_state: RectifiedFlowSchedulerState, - vae: TorchaxCausalVideoAutoencoder, - text_encoder, - patchifier, - tokenizer, - prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - transformer_state: Dict[Any, Any] = None, - transformer_state_shardings: Dict[Any, Any] = NoneType, - ): - self.transformer = transformer - self.devices_array = devices_array - self.mesh = mesh - self.config = config - self.p_run_inference = None - self.transformer_state = transformer_state - self.transformer_state_shardings = transformer_state_shardings - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.vae = vae - self.text_encoder = text_encoder - self.patchifier = patchifier - self.tokenizer = tokenizer - self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model - self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor - self.prompt_enhancer_llm_model = prompt_enhancer_llm_model - self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer - self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - @classmethod - def load_scheduler(cls, ckpt_path, config): - if config.sampler == "from_checkpoint" or not config.sampler: - scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) - else: - scheduler = FlaxRectifiedFlowMultistepScheduler( - sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") - ) - scheduler_state = scheduler.create_state() - - return scheduler, scheduler_state - - @classmethod - def load_transformer(cls, config): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) +def save_tensor_dict(tensor_dict, timestep): base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True - ) - ##load in jax weights checkpoint - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - 
transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - return transformer, transformer_state, transformer_state_shardings - - @classmethod - def load_vae(cls, ckpt_path): - torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype = torch.bfloat16) - with default_env(): - torch_vae = torch_vae.to('jax') - jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) - return jax_vae - - @classmethod - def load_text_encoder(cls, ckpt_path): - t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) - return t5_encoder - - @classmethod - def load_tokenizer(cls, config, ckpt_path): - t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) - return t5_tokenizer - - @classmethod - def load_prompt_enhancement(cls, config): - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, - torch_dtype="bfloat16", - ) - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, - ) - return ( + local_path = os.path.join(base_dir, f"schedulerTest{timestep}") + + try: + torch.save(tensor_dict, local_path) + print(f"Dictionary of tensors saved to: {local_path}") + except Exception as e: + print(f"Error saving dictionary: {e}") + raise + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", + encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def prepare_extra_step_kwargs(generator): + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + +class LTXVideoPipeline: + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: TorchaxCausalVideoAutoencoder, + text_encoder, + patchifier, + tokenizer, prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) - - # load from pytorch version - models_dir = config.models_dir - ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" - if not os.path.isfile(ltxv_model_name_or_path): - ltxv_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=ltxv_model_name_or_path, - local_dir=models_dir, - repo_type="model", - ) - 
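Note: this patch also introduces a pickle-backed Orbax checkpoint handler (PickleCheckpointHandler above). A standalone sketch of the same save/restore round trip, with the orbax base class dropped so it runs on its own; the class name and path here are illustrative:

import os
import pickle

class PicklePersistence:
    # Mirrors the handler above: one checkpoint.pkl file per checkpoint directory.
    def save(self, directory: str, item) -> None:
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "checkpoint.pkl"), "wb") as f:
            pickle.dump(item, f)

    def restore(self, directory: str):
        with open(os.path.join(directory, "checkpoint.pkl"), "rb") as f:
            return pickle.load(f)

handler = PicklePersistence()
handler.save("/tmp/ltx_ckpt", {"step": 0, "params": [1.0, 2.0]})
state = handler.restore("/tmp/ltx_ckpt")  # {'step': 0, 'params': [1.0, 2.0]}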
else: - ltxv_model_path = ltxv_model_name_or_path - - scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) - vae = cls.load_vae(ltxv_model_path) - vae = vae.to(torch.bfloat16) - text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) - patchifier = SymmetricPatchifier(patch_size=1) - tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) - - enhance_prompt = False - if enhance_prompt: - ( - prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer, - ) = cls.load_prompt_enhancement(config) - else: - ( - prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer, - ) = (None, None, None, None) - - return LTXVideoPipeline( - transformer=transformer, - scheduler=scheduler, - scheduler_state=scheduler_state, - vae=vae, - text_encoder=text_encoder, - patchifier=patchifier, - tokenizer=tokenizer, - prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model=prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, - devices_array=devices_array, - mesh=mesh, - config=config, - transformer_state=transformer_state, - transformer_state_shardings=transformer_state_shardings, - ) - - @classmethod - def _text_preprocessing(self, text): - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - text = text.strip() - return text - - return [process(t) for t in text] - - def denoising_step( - scheduler, - latents: Array, - noise_pred: Array, - current_timestep: Optional[Array], - conditioning_mask: Optional[Array], - t: float, - extra_step_kwargs: Dict, - t_eps: float = 1e-6, - stochastic_sampling: bool = False, - ) -> Array: - # Denoise the latents using the scheduler - denoised_latents = scheduler.step( - noise_pred, - t if current_timestep is None else current_timestep, - latents, - **extra_step_kwargs, - stochastic_sampling=stochastic_sampling, - ) - - if conditioning_mask is None: - return denoised_latents - - tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) - tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) - return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) - - def retrieve_timesteps( # currently doesn't support custom timesteps - self, - scheduler: FlaxRectifiedFlowMultistepScheduler, - latent_shape, - scheduler_state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - timesteps: Optional[List[int]] = None, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - ): - scheduler_state = scheduler.set_timesteps( - state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps - ) - timesteps = scheduler_state.timesteps - if ( - skip_initial_inference_steps < 0 - or skip_final_inference_steps < 0 - or skip_initial_inference_steps + skip_final_inference_steps >= num_inference_steps + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, ): - raise ValueError( - "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of 
inference steps" - ) - timesteps = timesteps[skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps] - scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) - - return scheduler_state - - def encode_prompt( - self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - text_encoder_max_tokens: int = 256, - **kwargs, - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - max_length = text_encoder_max_tokens - if prompt_embeds is None: - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," - - prompt = self._text_preprocessing(prompt) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) #noqa: F841 - - prompt_attention_mask = jnp.array(text_inputs.attention_mask) - prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = self.text_encoder.dtype #noqa: F841 - elif self.transformer is not None: - dtype = self.transformer.dtype #noqa: F841 - else: - dtype = None #noqa: F841 - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) - prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) - prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = self._text_preprocessing(negative_prompt) - uncond_tokens = uncond_tokens * batch_size - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) - - negative_prompt_embeds = self.text_encoder( - jnp.array(uncond_input.input_ids), - attention_mask=negative_prompt_attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) - negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, 
seq_len, -1)) - - negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, (1, num_images_per_prompt)) - negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - - return ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) - - def prepare_latents( ## this is in pytorch - self, - latents: torch.Tensor | None, - media_items: torch.Tensor | None, - timestep: float, - latent_shape: torch.Size | Tuple[Any, ...], - dtype: torch.dtype, - device: torch.device, - generator: torch.Generator | List[torch.Generator], - vae_per_channel_normalize: bool = True, - ): - if isinstance(generator, list) and len(generator) != latent_shape[0]: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." - ) - - # Initialize the latents with the given latents or encoded media item, if provided - assert ( - latents is None or media_items is None - ), "Cannot provide both latents and media_items. Please provide only one of the two." - - assert ( - latents is None and media_items is None or timestep < 1.0 - ), "Input media_item or latents are provided, but they will be replaced with noise." - - if media_items is not None: - latents = vae_encode( - media_items, - self.vae, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - if latents is not None: - assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." - - # For backward compatibility, generate in the "patchified" shape and rearrange - b, c, f, h, w = latent_shape - noise = randn_tensor((b, f * h * w, c), generator=generator, device=device, dtype=dtype) - noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) - - # scale the initial noise by the standard deviation required by the scheduler - # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have - - if latents is None: - latents = noise - else: - # Noise the latents to the required (first) timestep - timestep = torch.from_numpy(np.array(timestep)) - latents = timestep * noise + (1 - timestep) * latents - - return latents - - def prepare_conditioning( # removed conditioning_item logic - self, - conditioning_items, - init_latents: torch.Tensor, - num_frames: int, - height: int, - width: int, - vae_per_channel_normalize: bool = True, - generator=None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - - assert isinstance(self.vae, CausalVideoAutoencoder) - - # Patchify the updated latents and calculate their pixel coordinates - init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) - init_pixel_coords = latent_to_pixel_coords( - init_latent_coords, - self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now - causal_fix=True, - ) - - if not conditioning_items: - return init_latents, init_pixel_coords, None, 0 - - def __call__( - self, - height: int, - width: int, - num_frames: int, - negative_prompt: str = "", - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - frame_rate: int = 30, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - 
prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - guidance_timesteps: Optional[List[int]] = None, - decode_timestep: Union[List[float], float] = 0.05, - decode_noise_scale: Optional[List[float]] = 0.025, - offload_to_cpu: bool = False, - enhance_prompt: bool = False, - text_encoder_max_tokens: int = 256, - num_inference_steps: int = 50, - guidance_scale: Union[float, List[float]] = 4.5, - rescaling_scale: Union[float, List[float]] = 0.7, - stg_scale: Union[float, List[float]] = 1.0, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - cfg_star_rescale: bool = False, - skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, - **kwargs, - ): - enhance_prompt = False - prompt = self.config.prompt - is_video = kwargs.get("is_video", False) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True) - - latent_height = height // self.vae_scale_factor - latent_width = width // self.vae_scale_factor - latent_num_frames = num_frames // self.video_scale_factor - if isinstance(self.vae, CausalVideoAutoencoder) and is_video: - latent_num_frames += 1 - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - - latent_shape = ( - batch_size * num_images_per_prompt, - model_config["in_channels"], - latent_num_frames, - latent_height, - latent_width, - ) - scheduler_state = self.retrieve_timesteps( - self.scheduler, + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( + self.vae + ) + # self.image_processor = VaeImageProcessor( + # vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + # "Wan-AI/Wan2.1-T2V-14B-Diffusers", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( + # "Lightricks/LTX-Video", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # import pdb; pdb.set_trace() + # scheduler = FlaxRectifiedFlowMultistepScheduler( + # sampler="LinearQuadratic" + # ) + # scheduler_state = scheduler.create_state() + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler 
= FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", + "causal_temporal_positioning", "in_channels", "ckpt_path"] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + transformer = Transformer3DModel( + # change this sharding back + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) #move the key outwards + transformer_param_shapes = transformer.init_weights( + in_channels, jax.random.PRNGKey(0), model_config['caption_channels'], eval_only=True) + weights_init_fn = functools.partial( + transformer.init_weights, + in_channels, + jax.random.PRNGKey(0), + model_config['caption_channels'], + eval_only=True + ) + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put( + transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + + torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype = torch.bfloat16) + with default_env(): + torch_vae = torch_vae.to('jax') + jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) + return jax_vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + # text_encoder = T5EncoderModel.from_pretrained( + # ckpt_path, subfolder="text_encoder" + # ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + cpus = jax.devices("cpu") + t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + # tokenizer = T5Tokenizer.from_pretrained( + # ckpt_path, subfolder="tokenizer" + # ) + t5_tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, max_length=config.max_sequence_length, use_fast=True + ) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + 
config.prompt_enhancer_llm_model_name_or_path, + ) + return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer( + config) + + # load from pytorch version + models_dir = "/mnt/disks/diffusionproj" #edit this! + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + + vae = cls.load_vae(ltxv_model_path) + text_encoder = cls.load_text_encoder( + config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer( + config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( + config) + else: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + """ + Perform the denoising step for the required tokens, based on the current timestep and + conditioning mask: + Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) + and will start to be denoised when the current timestep is equal or lower than their + conditioning timestep. 
+ (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) + """ + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( #currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, latent_shape, - self.scheduler_state, - num_inference_steps, - None, - skip_initial_inference_steps, - skip_final_inference_steps, - ) - - guidance_mapping = [] - - if guidance_timesteps: - for timestep in scheduler_state.timesteps: - indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep] - guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)) - - if not isinstance(guidance_scale, list): - guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) - else: - guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(stg_scale, list): - stg_scale = [stg_scale] * len(scheduler_state.timesteps) - else: - stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(rescaling_scale, list): - rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) - else: - rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) - do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) - do_rescaling = any(x != 1.0 for x in rescaling_scale) - - num_conds = 1 - if do_classifier_free_guidance: - num_conds += 1 - if do_spatio_temporal_guidance: - num_conds += 1 - - is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) - - if not is_list_of_lists: - skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) - else: - new_skip_block_list = [] - for i in range(len(scheduler_state.timesteps)): - new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) - - skip_block_list = new_skip_block_list - - if do_spatio_temporal_guidance: - if skip_block_list is not None: - skip_layer_masks = [ - self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) - for skip_blocks in skip_block_list + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps # Use the original num_inference_steps here for the check + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and 
skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[ + skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps ] - if enhance_prompt: - prompt = generate_cinematic_prompt( - self.prompt_enhancer_image_caption_model, - self.prompt_enhancer_image_caption_processor, - self.prompt_enhancer_llm_model, - self.prompt_enhancer_llm_tokenizer, - prompt, - None, # conditioning items set to None - max_new_tokens=text_encoder_max_tokens, - ) - - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - device=None, # device set to none - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - text_encoder_max_tokens=text_encoder_max_tokens, - ) - prompt_embeds_batch = prompt_embeds - prompt_attention_mask_batch = prompt_attention_mask - if do_classifier_free_guidance: - prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) - if do_spatio_temporal_guidance: - prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [ - prompt_attention_mask_batch, - prompt_attention_mask, - ], - axis=0, - ) - latents = self.prepare_latents( - latents=latents, - media_items=None, # set to None - timestep=scheduler_state.timesteps[0], - latent_shape=latent_shape, - dtype=None, - device=None, - generator=generator, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - - latents, pixel_coords, conditioning_mask, num_cond_latents = self.prepare_conditioning( - conditioning_items=None, - init_latents=latents, - num_frames=num_frames, - height=height, - width=width, - vae_per_channel_normalize=vae_per_channel_normalize, - generator=generator, - ) - - - pixel_coords = torch.cat([pixel_coords] * num_conds) - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - - noise_cond = jnp.ones((1, 1)) # initialize first round with this! 
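Note: retrieve_timesteps above trims the scheduler's timestep list by the requested skip windows and then re-seeds the scheduler state with the shortened schedule. A simplified standalone sketch of the trimming, using the schedule length in place of num_inference_steps:

import jax.numpy as jnp

def trim_timesteps(timesteps, skip_initial: int, skip_final: int):
    # Reject windows that would leave no steps, matching the ValueError above.
    if skip_initial < 0 or skip_final < 0 or skip_initial + skip_final >= len(timesteps):
        raise ValueError("invalid skip inference step values")
    return timesteps[skip_initial : len(timesteps) - skip_final]

# Example: a 10-step schedule with the first 2 and the last 3 steps skipped.
full = jnp.linspace(1.0, 0.1, 10)
trimmed = trim_timesteps(full, skip_initial=2, skip_final=3)  # 5 timesteps remain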
- p_run_inference = functools.partial( - run_inference, - transformer=self.transformer, - config=self.config, - mesh=self.mesh, - fractional_cords=jnp.array(fractional_coords.to(torch.float32).detach().numpy()), - prompt_embeds=prompt_embeds_batch, - segment_ids=None, - encoder_attention_segment_ids=prompt_attention_mask_batch, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - do_classifier_free_guidance=do_classifier_free_guidance, - num_conds=num_conds, - guidance_scale=guidance_scale, - do_spatio_temporal_guidance=do_spatio_temporal_guidance, - stg_scale=stg_scale, - do_rescaling=do_rescaling, - rescaling_scale=rescaling_scale, - batch_size=batch_size, - skip_layer_masks=skip_layer_masks, - cfg_star_rescale=cfg_star_rescale, - ) - - with self.mesh: - latents, scheduler_state = p_run_inference( - transformer_state=self.transformer_state, - latents=jnp.array(latents.to(torch.float32).detach().numpy()), - timestep=noise_cond, - scheduler_state=scheduler_state, - ) - latents = latents[:, num_cond_latents:] - - latents = self.patchifier.unpatchify( - latents=latents, - output_height=latent_height, - output_width=latent_width, - out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), - ) - if output_type != "latent": - if self.vae.decoder.timestep_conditioning: - noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) #move the key to outer layer - - # Convert decode_timestep to a list if it's not already one - if not isinstance(decode_timestep, (list, jnp.ndarray)): - decode_timestep = [decode_timestep] * latents.shape[0] - - # Handle decode_noise_scale - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): - decode_noise_scale = [decode_noise_scale] * latents.shape[0] - - # Convert lists to JAX arrays - decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) - - # Reshape decode_noise_scale for broadcasting - decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) - decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) - - # Apply the noise and scale - latents = ( - latents * (1 - decode_noise_scale) + - noise * decode_noise_scale + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) + + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = ( + text_encoder_max_tokens # TPU supports only lengths multiple of 128 + ) + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + # text_enc_device = next(self.text_encoder.parameters()) + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + 
padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, max_length - 1: -1] + ) + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + print(isinstance(prompt_embeds, Array)) + # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + # prompt_embeds = prompt_embeds.view( + # bs_embed * num_images_per_prompt, seq_len, -1 + # ) + prompt_embeds = jnp.reshape( + prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile( + prompt_attention_mask, (1, num_images_per_prompt)) + # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + # prompt_attention_mask = prompt_attention_mask.view( + # bs_embed * num_images_per_prompt, -1 + # ) + prompt_attention_mask = jnp.reshape( + prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance hasn't changed yet + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, + (1, num_images_per_prompt, 1) + ) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, + (batch_size * num_images_per_prompt, seq_len, -1) + ) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, + (1, num_images_per_prompt) + ) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, + (bs_embed * num_images_per_prompt, -1) + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + data_sharding = jax.sharding.NamedSharding(self.mesh, P('data')) + return ( + jax.device_put(prompt_embeds, data_sharding),#(1, 256, 4096) + jax.device_put(prompt_attention_mask, data_sharding), #1, 256 + jax.device_put(negative_prompt_embeds, data_sharding), + jax.device_put(negative_prompt_attention_mask, data_sharding) + ) 
+ + def prepare_latents( + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = self.vae.encode( + media_itmes = media_items, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert ( + latents.shape == latent_shape + ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." + + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor( ###need to remove this! + (b, f * h * w, c), generator=generator, device=device, dtype=dtype + ) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise else: - decode_timestep = None - image = self.vae.decode( - latents = jax.device_put(latents, jax.devices('tpu')[0]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]), - is_video = is_video, - vae_per_channel_normalize=kwargs.get( - "vae_per_channel_normalize", True), - timestep=decode_timestep #.astype(jnp.bfloat16), + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, TorchaxCausalVideoAutoencoder) + + # if conditioning_items: + # batch_size, _, num_latent_frames = init_latents.shape[:3] + + # init_conditioning_mask = torch.zeros( + # init_latents[:, 0, :, :, :].shape, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents = [] + # extra_conditioning_pixel_coords = [] + # extra_conditioning_mask = [] + # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # # Process each conditioning item + # for conditioning_item in conditioning_items: + # conditioning_item = self._resize_conditioning_item( + # conditioning_item, height, width + # ) + # media_item = conditioning_item.media_item + # media_frame_number = conditioning_item.media_frame_number + # strength = conditioning_item.conditioning_strength + # assert media_item.ndim == 5 # (b, c, f, h, w) + # b, c, n_frames, h, w = media_item.shape + # 
assert ( + # height == h and width == w + # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + # assert n_frames % 8 == 1 + # assert ( + # media_frame_number >= 0 + # and media_frame_number + n_frames <= num_frames + # ) + + # # Encode the provided conditioning media item + # media_item_latents = vae_encode( + # media_item.to(dtype=self.vae.dtype, device=self.vae.device), + # self.vae, + # vae_per_channel_normalize=vae_per_channel_normalize, + # ).to(dtype=init_latents.dtype) + + # # Handle the different conditioning cases + # if media_frame_number == 0: + # # Get the target spatial position of the latent conditioning item + # media_item_latents, l_x, l_y = self._get_latent_spatial_position( + # media_item_latents, + # conditioning_item, + # height, + # width, + # strip_latent_border=True, + # ) + # b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # # First frame or sequence - just update the initial noise latents and the mask + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + # torch.lerp( + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + # media_item_latents, + # strength, + # ) + # ) + # init_conditioning_mask[ + # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + # ] = strength + # else: + # # Non-first frame or sequence + # if n_frames > 1: + # # Handle non-first sequence. + # # Encoded latents are either fully consumed, or the prefix is handled separately below. + # ( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # ) = self._handle_non_first_conditioning_sequence( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # media_frame_number, + # strength, + # ) + + # # Single frame or sequence-prefix latents + # if media_item_latents is not None: + # noise = randn_tensor( + # media_item_latents.shape, + # generator=generator, + # device=media_item_latents.device, + # dtype=media_item_latents.dtype, + # ) + + # media_item_latents = torch.lerp( + # noise, media_item_latents, strength + # ) + + # # Patchify the extra conditioning latents and calculate their pixel coordinates + # media_item_latents, latent_coords = self.patchifier.patchify( + # latents=media_item_latents + # ) + # pixel_coords = latent_to_pixel_coords( + # latent_coords, + # self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, + # ) + + # # Update the frame numbers to match the target frame number + # pixel_coords[:, 0] += media_frame_number + # extra_conditioning_num_latents += media_item_latents.shape[1] + + # conditioning_mask = torch.full( + # media_item_latents.shape[:2], + # strength, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents.append(media_item_latents) + # extra_conditioning_pixel_coords.append(pixel_coords) + # extra_conditioning_mask.append(conditioning_mask) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify( + latents=init_latents + ) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True + + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + # init_conditioning_mask, _ = self.patchifier.patchify( + # latents=init_conditioning_mask.unsqueeze(1) + # ) + # init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + # if 
extra_conditioning_latents: + # # Stack the extra conditioning latents, pixel coordinates and mask + # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + # init_pixel_coords = torch.cat( + # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + # ) + # init_conditioning_mask = torch.cat( + # [*extra_conditioning_mask, init_conditioning_mask], dim=1 + # ) + + # if self.transformer.use_tpu_flash_attention: + # # When flash attention is used, keep the original number of tokens by removing + # # tokens from the end. + # init_latents = init_latents[:, :-extra_conditioning_num_latents] + # init_pixel_coords = init_pixel_coords[ + # :, :, :-extra_conditioning_num_latents + # ] + # init_conditioning_mask = init_conditioning_mask[ + # :, :-extra_conditioning_num_latents + # ] + + # return ( + # init_latents, + # init_pixel_coords, + # init_conditioning_mask, + # extra_conditioning_num_latents, + # ) + + # change the paramters of these, currently pass in dummy inputs + + #copied from image_processor ?? change?? move to a seperate file???? + + def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. + """ + return (images * 0.5 + 0.5).clamp(0, 1) + + def _denormalize_conditionally( + self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None + ) -> torch.Tensor: + r""" + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) + + return torch.stack( + [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])] + ) + #same as diffusers image_processor.postprocess + + def postprocess_to_output_type(self, image, output_type): #support latent & pt + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support pytorch tensor" + ) + + if output_type not in ["latent", "pt", "np", "pil"]: + output_type = "np" + + if output_type == "latent": + return image + image = self._denormalize_conditionally(image, None) #do denormalize set to none + if output_type == "pt": + return image + + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + vae_per_channel_normalize = kwargs.get( + "vae_per_channel_normalize", True) + + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, TorchaxCausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) + + + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [ + i for i, val in enumerate(guidance_timesteps) if val <= timestep + ] + guidance_mapping.append( + indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) + ) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + 
rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask( + batch_size, num_conds, num_conds - 1, skip_blocks + ) + for skip_blocks in skip_block_list + ] + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate( + [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension + ) + prompt_attention_mask_batch = jnp.concatenate( + [negative_prompt_attention_mask, prompt_attention_mask], axis=0 + ) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = ( + self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) ) - image = self.postprocess_to_output_type( #swap this out! 
- torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type) - else: - image = latents + extra_step_kwargs = prepare_extra_step_kwargs( + generator=jax.random.PRNGKey(0)) + + pixel_coords = torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + # if not isinstance(guidance_scale, List): + # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) + # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + # latents = jax.device_put(example_inputs["latents"], data_sharding) + # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + noise_cond = jnp.ones( # initialize first round with this! + (4, 1) + ) - # Offload all models + # # noise_cond = None + # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" + # tensor_dict = torch.load(saved_tensor_path) + + # for key, value in tensor_dict.items(): + # if value is not None: + # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) + # example_inputs = tensor_dict + # latents = jax.device_put(example_inputs["latent_model_input"]) + # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) + # fractional_coords = jax.device_put(example_inputs["indices_grid"]) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) + # segment_ids = None + # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + + # #only run this for the first time! 
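# A minimal sketch of the coordinate preparation above: the active code scales the
# patchifier's pixel-coordinate grid so that the frame index becomes a timestamp
# (frame / frame_rate) before handing it to JAX. The dummy (batch, 3, num_tokens)
# grid and the 30 fps value below are illustrative assumptions, not the pipeline's API.
import numpy as np
import jax.numpy as jnp
import torch

def to_fractional_coords(pixel_coords: torch.Tensor, num_conds: int, frame_rate: float) -> jnp.ndarray:
    # Repeat the grid once per conditioning branch (uncond / text / perturbed).
    coords = torch.cat([pixel_coords] * num_conds).to(torch.float32)
    # Coordinate axis 0 indexes frames; dividing by the frame rate yields timestamps.
    coords[:, 0] = coords[:, 0] * (1.0 / frame_rate)
    # Convert to a JAX array for the jitted denoising loop.
    return jnp.asarray(coords.detach().numpy())

fractional = to_fractional_coords(torch.zeros(1, 3, 16), num_conds=2, frame_rate=30)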
+ # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) + # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken + # scheduler_state = self.scheduler_state + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = jax.jit( + # functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # ), + # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct + # out_shardings=None, + # ) + segment_ids = None + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # # guidance_scale=guidance_scale + # ) + p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array( + fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling = do_rescaling, + rescaling_scale = rescaling_scale, + batch_size=batch_size, + skip_layer_masks = skip_layer_masks, + skip_layer_strategy = skip_layer_strategy, + cfg_star_rescale = cfg_star_rescale + ) - if not return_dict: - return (image,) + with self.mesh: + latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( + # add scheduler state back in + torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) + # latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] + // math.prod(self.patchifier.patch_size), + ) + + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) #move the key to outer layer + + # Convert decode_timestep to a list if it's not already one + if not isinstance(decode_timestep, (list, jnp.ndarray)): + decode_timestep = [decode_timestep] * latents.shape[0] + + # Handle decode_noise_scale + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, (list, 
jnp.ndarray)): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + # Convert lists to JAX arrays + decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) + + # Reshape decode_noise_scale for broadcasting + decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) + decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) + + # Apply the noise and scale + latents = ( + latents * (1 - decode_noise_scale) + + noise * decode_noise_scale + ) + else: + decode_timestep = None + image = self.vae.decode( + latents = jax.device_put(latents, jax.devices('tpu')[0]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]), + is_video = is_video, + vae_per_channel_normalize=kwargs.get( + "vae_per_channel_normalize", True), + timestep=decode_timestep #.astype(jnp.bfloat16), + ) + image = self.postprocess_to_output_type( #swap this out! + torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type) - return image + else: + image = latents + # Offload all models -def transformer_forward_pass( + if not return_dict: + return (image,) + + return image + # save states here + + +def transformer_forward_pass( # need to jit this? wan didnt latents, state, noise_cond, @@ -798,113 +1213,120 @@ def transformer_forward_pass( segment_ids, encoder_attention_segment_ids, skip_layer_mask, + skip_layer_strategy, ): - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - skip_layer_mask=skip_layer_mask, - ) - return noise_pred, state + + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy + ) # need .param here? 
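  # Note on the `transformer.apply(...)` call above: Flax modules hold no parameters
  # themselves, so the weights are passed in explicitly via the variables dict
  # {"params": state.params} on every call. `apply` returns only the module outputs;
  # `state` is not mutated here and is returned unchanged so the caller can keep
  # threading the same train state through the denoising loop.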
+ return noise_pred, state def run_inference( - transformer_state, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - num_inference_steps, - scheduler, - segment_ids, - encoder_attention_segment_ids, - scheduler_state, - do_classifier_free_guidance, - num_conds, - guidance_scale, - do_spatio_temporal_guidance, - stg_scale, - do_rescaling, - rescaling_scale, - batch_size, - skip_layer_masks, - cfg_star_rescale, + transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale ): - for i, t in enumerate(scheduler_state.timesteps): - current_timestep = t - latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents - if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - if isinstance(current_timestep, float): - dtype = jnp.float32 - else: - dtype = jnp.int32 - - current_timestep = jnp.array( - [current_timestep], - dtype=dtype, - ) - elif current_timestep.ndim == 0: - current_timestep = jnp.expand_dims(current_timestep, axis=0) - - # Broadcast to batch dimension - current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) - - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, - transformer_state, - current_timestep, - transformer, - fractional_cords, - prompt_embeds, - segment_ids, - encoder_attention_segment_ids, - skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), - ) - - if do_spatio_temporal_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_text = chunks[-2] - noise_pred_text_perturb = chunks[-1] - - if do_classifier_free_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_uncond = chunks[0] - noise_pred_text = chunks[1] - if cfg_star_rescale: - positive_flat = noise_pred_text.reshape(batch_size, -1) - negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) - squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 - alpha = dot_product / squared_norm - alpha = alpha.reshape(batch_size, 1, 1) - - noise_pred_uncond = alpha * noise_pred_uncond - noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) - elif do_spatio_temporal_guidance: - noise_pred = noise_pred_text - - if do_spatio_temporal_guidance: - noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) - if do_rescaling and stg_scale[i] > 0.0: - noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) - noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) - - factor = noise_pred_text_std / noise_pred_std - factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) - - noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() - - return latents, scheduler_state + # do_classifier_free_guidance = guidance_scale > 1.0 + # for step in range(num_inference_steps): + for i, t in enumerate(scheduler_state.timesteps): + 
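    # One iteration per scheduler timestep: replicate the latents across the active
    # conditioning branches, run the transformer once over the whole batch, recombine the
    # per-branch noise predictions with CFG / STG below, then take a scheduler step.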
current_timestep = t + latent_model_input = ( + jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + ) + # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) + # timestep = jnp.broadcast_to( + # t, (latent_model_input.shape[0],) + # ).reshape(-1, 1) + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + # Determine the correct dtype based on the device (similar to PyTorch) + is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. + if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + # Broadcast to batch dimension (compatible with ONNX/Core ML) + current_timestep = jnp.broadcast_to( + current_timestep, (latent_model_input.shape[0],1) + ) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( + skip_layer_masks[i] + if skip_layer_masks is not None + else None + ), skip_layer_strategy = skip_layer_strategy) + # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) + + # # latents = self.denoising + # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated + if do_spatio_temporal_guidance: + # JAX uses jnp.split for splitting along an axis. 
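# A self-contained sketch of the branch recombination implemented below, assuming two
# branches [uncond, text] stacked along axis 0 (the real loop may also carry a third,
# perturbed branch for STG). With cfg_star_rescale, the unconditional prediction is first
# rescaled by its projection onto the text prediction, then the usual CFG blend is applied.
# Shapes and values are illustrative only.
import jax.numpy as jnp

def combine_cfg(noise_pred, guidance_scale, batch_size, cfg_star_rescale=False):
    noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
    if cfg_star_rescale:
        pos = noise_pred_text.reshape(batch_size, -1)
        neg = noise_pred_uncond.reshape(batch_size, -1)
        alpha = jnp.sum(pos * neg, axis=1, keepdims=True) / (jnp.sum(neg**2, axis=1, keepdims=True) + 1e-8)
        noise_pred_uncond = alpha.reshape(batch_size, 1, 1) * noise_pred_uncond
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

combined = combine_cfg(jnp.ones((2, 3072, 128)), guidance_scale=4.5, batch_size=1, cfg_star_rescale=True)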
+ # Equivalent to .chunk() for equal-sized chunks + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum( #(1, 1) + positive_flat * negative_flat, axis=1, keepdims=True + ) + squared_norm = ( #(1, 1) + jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + ) + alpha = dot_product / squared_norm #might need to reshape this (1, 1) + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) + noise_pred = noise_pred_uncond + guidance_scale[i] * ( + noise_pred_text - noise_pred_uncond + ) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] # JAX slicing is similar + latents, scheduler_state = scheduler.step( + scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state def adain_filter_latent( latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 @@ -937,110 +1359,81 @@ def adain_filter_latent( result = torch.lerp(latents, result, factor) return result +class LTXMultiScalePipeline: ##figure these methods out + def _upsample_latents( + self, latest_upsampler: LatentUpsampler, latents: torch.Tensor + ): + latents = jax.device_put(latents, jax.devices('tpu')[0]) + #assert latents.device == latest_upsampler.device + with default_env(): + latents = un_normalize_latents( #need to switch this out? 
+ interop.torch_view(latents), self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) #here converted back to torch, cause upsampler in pytorch + upsampled_latents = normalize_latents( + interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents -class LTXMultiScalePipeline: - - @classmethod - def load_latent_upsampler(cls, config): - spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir=config.models_dir, - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) - latent_upsampler.eval() - return latent_upsampler - - def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): - latents = jax.device_put(latents, jax.devices('tpu')[0]) - #assert latents.device == latest_upsampler.device - with default_env(): - latents = un_normalize_latents( #need to switch this out? - interop.torch_view(latents), self.vae, vae_per_channel_normalize=True - ) - upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) #here converted back to torch, cause upsampler in pytorch - upsampled_latents = normalize_latents( - interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True + def __init__( + self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + ): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + self.latent_upsampler = latent_upsampler + + def __call__( + self, + height, + width, + num_frames, + is_video, + output_type, + generator, + config + ) -> Any: + + original_output_type = output_type + original_width = width + original_height = height + x_width = int(width * config.downscale_factor) + downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) + x_height = int(height * config.downscale_factor) + downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) + # #use original height and width here to see + output_type = 'latent' + result = self.video_pipeline(height=original_height, width=original_width, num_frames=num_frames, + is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], + num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) + latents = result + print("done") + upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) #convert back to pytorch here + ##maybe change this? 
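# The first-pass latents live on the JAX side while the latent upsampler is a PyTorch
# module, so the surrounding code round-trips arrays through NumPy. A minimal sketch of
# that hand-off, with torch.nn.Identity standing in for the real upsampler and an
# illustrative latent shape:
import numpy as np
import jax.numpy as jnp
import torch

def run_torch_module_on_jax(latents_jax, torch_module):
    latents_pt = torch.from_numpy(np.asarray(latents_jax))  # JAX -> NumPy -> torch
    with torch.no_grad():
        out_pt = torch_module(latents_pt)
    return jnp.asarray(out_pt.detach().cpu().numpy())  # torch -> NumPy -> JAX

upsampled = run_torch_module_on_jax(jnp.ones((1, 128, 8, 16, 16), dtype=jnp.float32), torch.nn.Identity())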
+ latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu')) + upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu')) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents ) - return upsampled_latents - - def __init__(self, video_pipeline: LTXVideoPipeline): - self.video_pipeline = video_pipeline - self.vae = video_pipeline.vae - - def __call__(self, height, width, num_frames, output_type, generator, config) -> Any: - - latent_upsampler = self.load_latent_upsampler(config) - original_output_type = output_type - output_type = "latent" - result = self.video_pipeline( - height=height, - width=width, - num_frames=num_frames, - is_video=True, - output_type=output_type, - generator=generator, - guidance_scale=config.first_pass["guidance_scale"], - stg_scale=config.first_pass["stg_scale"], - rescaling_scale=config.first_pass["rescaling_scale"], - skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"], - skip_final_inference_steps=config.first_pass["skip_final_inference_steps"], - num_inference_steps=config.first_pass["num_inference_steps"], - guidance_timesteps=config.first_pass["guidance_timesteps"], - cfg_star_rescale=config.first_pass["cfg_star_rescale"], - skip_block_list=config.first_pass["skip_block_list"], - ) - latents = result - upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here - - latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu')) - upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu')) - upsampled_latents = adain_filter_latent( - latents=upsampled_latents, reference_latents=latents - ) - latents = upsampled_latents - output_type = original_output_type - - result = self.video_pipeline( - height=height * 2, - width=width * 2, - num_frames=num_frames, - is_video=True, - output_type=output_type, - latents=latents, - generator=generator, - guidance_scale=config.second_pass["guidance_scale"], - stg_scale=config.second_pass["stg_scale"], - rescaling_scale=config.second_pass["rescaling_scale"], - skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"], - skip_final_inference_steps=config.second_pass["skip_final_inference_steps"], - num_inference_steps=config.second_pass["num_inference_steps"], - guidance_timesteps=config.second_pass["guidance_timesteps"], - cfg_star_rescale=config.second_pass["cfg_star_rescale"], - skip_block_list=config.second_pass["skip_block_list"], - ) - - if original_output_type != "latent": - num_frames = result.shape[2] - videos = rearrange(result, "b c f h w -> (b f) c h w") - - videos = F.interpolate( - videos, - size=(height, width), - mode="bilinear", - align_corners=False, - ) - videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) - result = videos + latents = upsampled_latents + output_type = original_output_type + width = downscaled_width * 2 + height = downscaled_height * 2 + # import pdb; pdb.set_trace() + result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, + is_video=True, output_type=original_output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= 
config.second_pass["skip_final_inference_steps"], + num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.second_pass["skip_block_list"]) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(original_height, original_width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos - return result + return result \ No newline at end of file From fd9eb11227afd41063ac6bb9079c0bada428d00b Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 21 Jul 2025 00:35:29 +0000 Subject: [PATCH 54/69] error attribute weight already exist --- src/maxdiffusion/configs/ltx_video.yml | 2 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 31 ++++++++++++++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 0fdbe7f9f..6d5759764 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -65,7 +65,7 @@ mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_heads', 'fsdp'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', 'data'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 547445570..cfb8a72e5 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -27,6 +27,7 @@ from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from torchax import interop from torchax import default_env import imageio @@ -1360,6 +1361,28 @@ def adain_filter_latent( return result class LTXMultiScalePipeline: ##figure these methods out + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=config.models_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + def _upsample_latents( self, latest_upsampler: LatentUpsampler, latents: torch.Tensor ): @@ -1376,23 +1399,20 @@ def _upsample_latents( return upsampled_latents def __init__( - self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + self, video_pipeline: LTXVideoPipeline ): self.video_pipeline = video_pipeline self.vae = video_pipeline.vae - self.latent_upsampler = latent_upsampler def __call__( self, height, width, num_frames, - 
is_video, output_type, generator, config ) -> Any: - original_output_type = output_type original_width = width original_height = height @@ -1407,7 +1427,8 @@ def __call__( num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) latents = result print("done") - upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) #convert back to pytorch here + latent_upsampler = self.load_latent_upsampler(config) + upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here ##maybe change this? latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu')) upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu')) From 4c9be69b59e9993a5725f6f1935b0906f1760fe4 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 21 Jul 2025 23:45:23 +0000 Subject: [PATCH 55/69] baseline pipeline cleaned --- | 0 src/maxdiffusion/configs/ltx_video.yml | 4 +- src/maxdiffusion/generate_ltx_video.py | 55 +- .../models/ltx_video/autoencoders/__init__.py | 0 .../ltx_video/autoencoders/causal_conv3d.py | 103 +- .../autoencoders/causal_video_autoencoder.py | 2369 ++++++++--------- .../ltx_video/autoencoders/conv_nd_factory.py | 120 +- .../ltx_video/autoencoders/dual_conv3d.py | 385 ++- .../autoencoders/latent_upsampler.py | 357 ++- .../ltx_video/autoencoders/pixel_norm.py | 13 +- .../ltx_video/autoencoders/pixel_shuffle.py | 55 +- .../models/ltx_video/autoencoders/vae.py | 686 +++-- .../ltx_video/autoencoders/vae_encode.py | 370 ++- .../ltx_video/autoencoders/vae_torchax.py | 137 +- .../autoencoders/video_autoencoder.py | 1745 ++++++------ .../ltx_video/transformers/attention.py | 2 +- .../transformers/symmetric_patchifier.py | 130 +- .../ltx_video/transformers/transformer3d.py | 2 +- .../models/ltx_video/utils/__init__.py | 2 +- .../utils/diffusers_config_mapping.py | 16 +- .../ltx_video/utils/prompt_enhance_utils.py | 210 +- .../ltx_video/utils/skip_layer_strategy.py | 8 +- .../models/ltx_video/utils/torch_utils.py | 28 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 2316 +++++++--------- .../schedulers/scheduling_rectified_flow.py | 2 +- 25 files changed, 4221 insertions(+), 4894 deletions(-) create mode 100644 create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/__init__.py diff --git a/ b/ new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 6d5759764..1259d2dab 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -10,6 +10,7 @@ activations_dtype: 'bfloat16' run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False + #Checkpoints text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" @@ -23,12 +24,13 @@ sampler: "from_checkpoint" # Generation parameters -pipeline_type: multi-scale +pipeline_type: None prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. 
His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." height: 512 width: 512 num_frames: 88 #344 flow_shift: 5.0 +fps: 24 downscale_factor: 0.6666666 spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" prompt_enhancement_words_threshold: 120 diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index f7d7e6d03..5383ab3cd 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,9 +4,16 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig +import jax.numpy as jnp +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from huggingface_hub import hf_hub_download import imageio from datetime import datetime +from maxdiffusion.utils import export_to_video + import os +import json import torch from pathlib import Path @@ -55,6 +62,13 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: return "-".join(result) +def create_latent_upsampler(latent_upsampler_model_path: str, device: str): + latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) + latent_upsampler.to(device) + latent_upsampler.eval() + return latent_upsampler + + def get_unique_filename( base: str, ext: str, @@ -80,15 +94,52 @@ def run(config): width_padded = ((config.width - 1) // 32 + 1) * 32 num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 padding = calculate_padding(config.height, config.width, height_padded, width_padded) + # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold + # prompt_word_count = len(config.prompt.split()) + # enhance_prompt = ( + # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold + # ) - seed = 10 + seed = 10 # change this, generator in pytorch, used in prepare_latents generator = torch.Generator().manual_seed(seed) pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False) - pipeline = LTXMultiScalePipeline(pipeline) + if config.pipeline_type == "multi-scale": # move this to pipeline file?? 
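    # Multi-scale rendering wraps the base pipeline: the first pass renders at a resolution
    # reduced by config.downscale_factor, the resulting latents are upsampled with the
    # spatial upscaler loaded below (downloaded from the Lightricks/LTX-Video repo when only
    # a filename is configured), and a second pass re-runs the pipeline at doubled
    # resolution to refine the result.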
+ spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir="/mnt/disks/diffusionproj", + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = create_latent_upsampler(spatial_upscaler_model_path, "cpu") # device set to cpu for now + pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) + stg_mode = config.stg_mode + if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": + skip_layer_strategy = SkipLayerStrategy.AttentionValues + elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": + skip_layer_strategy = SkipLayerStrategy.AttentionSkip + elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": + skip_layer_strategy = SkipLayerStrategy.Residual + elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": + skip_layer_strategy = SkipLayerStrategy.TransformerBlock + else: + raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") + # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, + # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, + # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images images = pipeline( height=height_padded, width=width_padded, num_frames=num_frames_padded, + is_video=True, output_type="pt", generator=generator, config=config, diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py index 98249c2f5..8797880d3 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -5,59 +5,54 @@ class CausalConv3d(nn.Module): - def __init__( - self, + + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( in_channels, out_channels, - kernel_size: int = 3, - stride: Union[int, Tuple[int]] = 1, - dilation: int = 1, - groups: int = 1, - spatial_padding_mode: str = 
"zeros", - **kwargs, - ): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - - kernel_size = (kernel_size, kernel_size, kernel_size) - self.time_kernel_size = kernel_size[0] - - dilation = (dilation, 1, 1) - - height_pad = kernel_size[1] // 2 - width_pad = kernel_size[2] // 2 - padding = (0, height_pad, width_pad) - - self.conv = nn.Conv3d( - in_channels, - out_channels, - kernel_size, - stride=stride, - dilation=dilation, - padding=padding, - padding_mode=spatial_padding_mode, - groups=groups, - ) - - def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x - - @property - def weight(self): - return self.conv.weight + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py index 544c759c9..439953ed6 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -31,1371 +31,1254 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_name_or_path = Path(pretrained_model_name_or_path) - if ( - pretrained_model_name_or_path.is_dir() - and (pretrained_model_name_or_path / "autoencoder.pth").exists() - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - std_of_means = data_dict["std-of-means"] - mean_of_means = data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( - std_of_means - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( - mean_of_means - ) - - elif 
pretrained_model_name_or_path.is_dir(): - config_path = pretrained_model_name_or_path / "vae" / "config.json" - with open(config_path, "r") as f: - config = make_hashable_key(json.load(f)) - - assert config in diffusers_and_ours_config_mapping, ( - "Provided diffusers checkpoint config for VAE is not suppported. " - "We only support diffusers configs found in Lightricks/LTX-Video." - ) - - config = diffusers_and_ours_config_mapping[config] - - state_dict_path = ( - pretrained_model_name_or_path - / "vae" - / "diffusion_pytorch_model.safetensors" - ) - - state_dict = {} - with safe_open(state_dict_path, framework="pt", device="cpu") as f: - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - for key in list(state_dict.keys()): - new_key = key - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - - state_dict[new_key] = state_dict.pop(key) - - elif pretrained_model_name_or_path.is_file() and str( - pretrained_model_name_or_path - ).endswith(".safetensors"): - state_dict = {} - with safe_open( - pretrained_model_name_or_path, framework="pt", device="cpu" - ) as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - configs = json.loads(metadata["config"]) - config = configs["vae"] - - video_vae = cls.from_config(config) - if "torch_dtype" in kwargs: - video_vae.to(kwargs["torch_dtype"]) - video_vae.load_state_dict(state_dict) - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "CausalVideoAutoencoder" - ), "config must have _class_name=CausalVideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - normalize_latent_channels = config.get("normalize_latent_channels", False) - - if use_quant_conv and latent_log_var in ["uniform", "constant"]: - raise ValueError( - f"latent_log_var={latent_log_var} requires use_quant_conv=False" - ) - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - base_channels=config.get("encoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - causal=config.get("causal_decoder", False), - timestep_conditioning=config.get("timestep_conditioning", False), - base_channels=config.get("decoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - dims = config["dims"] - return CausalVideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - normalize_latent_channels=normalize_latent_channels, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="CausalVideoAutoencoder", - 
dims=self.dims, - in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, - out_channels=self.decoder.conv_out.out_channels - // self.decoder.patch_size**2, - latent_channels=self.decoder.conv_in.in_channels, - encoder_blocks=self.encoder.blocks_desc, - decoder_blocks=self.decoder.blocks_desc, - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - causal_decoder=self.decoder.causal, - timestep_conditioning=self.decoder.timestep_conditioning, - normalize_latent_channels=self.normalize_latent_channels, - ) - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def spatial_downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.encoder.blocks_desc - if block[0] - in [ - "compress_space", - "compress_all", - "compress_all_res", - "compress_space_res", - ] - ] - ) - * self.encoder.patch_size - ) + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if pretrained_model_name_or_path.is_dir() and (pretrained_model_name_or_path / "autoencoder.pth").exists(): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json" + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)} + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = std_of_means + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = mean_of_means + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = pretrained_model_name_or_path / "vae" / "diffusion_pytorch_model.safetensors" + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str(pretrained_model_name_or_path).endswith(".safetensors"): + state_dict = {} + with safe_open(pretrained_model_name_or_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none") + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError(f"latent_log_var={latent_log_var} requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + 
latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 - @property - def temporal_downscale_factor(self): - return 2 ** len( + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( [ block for block in self.encoder.blocks_desc if block[0] in [ - "compress_time", + "compress_space", "compress_all", "compress_all_res", "compress_space_res", ] ] ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - if any([key.startswith("vae.") for key in state_dict.keys()]): - state_dict = { - key.replace("vae.", ""): value - for key, value in state_dict.items() - if key.startswith("vae.") - } - ckpt_state_dict = { - key: value - for key, value in state_dict.items() - if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - - model_keys = set(name for name, _ in self.named_modules()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - converted_state_dict = {} - for key, value in ckpt_state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - key_prefix = ".".join(key.split(".")[:-1]) - if "norm" in key and key_prefix not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - data_dict = { - key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value - for key, value in state_dict.items() - if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - if len(data_dict) > 0: - self.register_buffer("std_of_means", data_dict["std-of-means"]) - self.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) + def to_json_string(self) -> str: + import json - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer + return json.dumps(self.config.__dict__) - def set_use_tpu_flash_attention(self): - for block in self.decoder.up_blocks: - if isinstance(block, UNetMidBlock3D) and block.attention_blocks: - for attention_block in block.attention_blocks: - attention_block.set_use_tpu_flash_attention() + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = {key.replace("vae.", ""): value for key, value in state_dict.items() if key.startswith("vae.")} + ckpt_state_dict = {key: value for key, value in state_dict.items() if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)} + model_keys = set(name for name, 
_ in self.named_modules()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info(f"Removing key {key} from state_dict as it is not present in the model") + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. 
- """ - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - self.blocks_desc = blocks - - in_channels = in_channels * patch_size**2 - output_channel = base_channels - - self.conv_in = make_conv_nd( +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. 
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( dims=dims, - in_channels=in_channels, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, out_channels=output_channel, kernel_size=3, - stride=1, - padding=1, + stride=(2, 1, 1), causal=True, spatial_padding_mode=spatial_padding_mode, ) - - self.down_blocks = nn.ModuleList([]) - - for block_name, block_params in blocks: - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 1, 1), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(1, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block 
= make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown block: {block_name}") - - self.down_blocks.append(block) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var == "constant": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, - output_channel, - conv_out_channels, - 3, - padding=1, + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), causal=True, spatial_padding_mode=spatial_padding_mode, ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)(sample) - - sample = self.conv_norm_out(sample).to(torch.bfloat16) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - elif self.latent_log_var == "constant": - sample = sample[:, :-1, ...] 
- approx_ln_0 = ( - -30 - ) # this is the minimal clamp value in DiagonalGaussianDistribution objects - sample = torch.cat( - [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], - dim=1, - ) - - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - causal (`bool`, *optional*, defaults to `True`): - Whether to use causal convolutions or not. - """ - - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - causal: bool = True, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.layers_per_block = layers_per_block - out_channels = out_channels * patch_size**2 - self.causal = causal - self.blocks_desc = blocks - - # Compute output channel to be product of all channel-multiplier blocks - output_channel = base_channels - for block_name, block_params in list(reversed(blocks)): - block_params = block_params if isinstance(block_params, dict) else {} - if block_name == "res_x_y": - output_channel = output_channel * block_params.get("multiplier", 2) - if block_name == "compress_all": - output_channel = output_channel * block_params.get("multiplier", 1) - - self.conv_in = make_conv_nd( - dims, - in_channels, - output_channel, + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, kernel_size=3, - stride=1, - padding=1, + stride=(2, 2, 2), causal=True, spatial_padding_mode=spatial_padding_mode, ) - - self.up_blocks = nn.ModuleList([]) - - for block_name, block_params in list(reversed(blocks)): - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - 
inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = output_channel // block_params.get("multiplier", 2) - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=False, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - output_channel = output_channel // block_params.get("multiplier", 1) - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 2, 2), - residual=block_params.get("residual", False), - out_channels_reduction_factor=block_params.get("multiplier", 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown layer: {block_name}") - - self.up_blocks.append(block) - - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, - output_channel, - out_channels, - 3, - padding=1, + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), causal=True, spatial_padding_mode=spatial_padding_mode, ) + elif block_name == "compress_all_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels 
*= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample).to(torch.bfloat16) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample - self.gradient_checkpointing = False - - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter( - torch.tensor(1000.0, dtype=torch.float32) - ) - self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - output_channel * 2, 0 - ) - self.last_scale_shift_table = nn.Parameter( - torch.randn(2, output_channel) / output_channel**0.5 - ) - - def forward( - self, - sample: torch.FloatTensor, - target_shape, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - batch_size = sample.shape[0] - - sample = self.conv_in(sample, causal=self.causal) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. 
+ base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name == "compress_all": + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + 
output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") - sample = sample.to(upscale_dtype) - - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - sample = self.conv_norm_out(sample).to(torch.bfloat16) - - if self.timestep_conditioning: - embedded_timestep = self.last_time_embedder( - timestep=scaled_timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=sample.shape[0], - hidden_dtype=sample.dtype, - ) - embedded_timestep = embedded_timestep.view( - batch_size, embedded_timestep.shape[-1], 1, 1, 1 - ) - ada_values = self.last_scale_shift_table[ - None, ..., None, None, None - ] + embedded_timestep.reshape( - batch_size, - 2, - -1, - embedded_timestep.shape[-3], - embedded_timestep.shape[-2], - embedded_timestep.shape[-1], - ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift - - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) + self.up_blocks.append(block) - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - return sample + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.last_scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert timestep is not None, "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and 
isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)(sample, causal=self.causal, timestep=scaled_timestep) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + sample = self.conv_norm_out(sample).to(torch.bfloat16) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - inject_noise (`bool`, *optional*, defaults to `False`): - Whether to inject noise into the hidden states. - timestep_conditioning (`bool`, *optional*, defaults to `False`): - Whether to condition the hidden states on the timestep. - attention_head_dim (`int`, *optional*, defaults to -1): - The dimension of the attention head. If -1, no attention is used. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. - - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - attention_head_dim: int = -1, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - in_channels * 4, 0 + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. 
Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - inject_noise=inject_noise, - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - for _ in range(num_layers) - ] + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError("attention_head_dim must be less than or equal to in_channels") + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert timestep is not None, "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, frames * height * width).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in 
flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), "constant", 0) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=(None if not attention.use_tpu_flash_attention else mask), ) - self.attention_blocks = None - - if attention_head_dim > 0: - if attention_head_dim > in_channels: - raise ValueError( - "attention_head_dim must be less than or equal to in_channels" - ) - - self.attention_blocks = nn.ModuleList( - [ - Attention( - query_dim=in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - bias=True, - out_bias=True, - qk_norm="rms_norm", - residual_connection=True, - ) - for _ in range(num_layers) - ] - ) + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] - def forward( - self, - hidden_states: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - timestep_embed = None - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - batch_size = hidden_states.shape[0] - timestep_embed = self.time_embedder( - timestep=timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - timestep_embed = timestep_embed.view( - batch_size, timestep_embed.shape[-1], 1, 1, 1 - ) + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, frames, height, width) + else: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed) - if self.attention_blocks: - for resnet, attention in zip(self.res_blocks, self.attention_blocks): - hidden_states = resnet( - hidden_states, causal=causal, timestep=timestep_embed - ) - - # Reshape the hidden states to be (batch_size, frames * height * width, channel) - batch_size, channel, frames, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, frames * height * width - ).transpose(1, 2) - - if attention.use_tpu_flash_attention: - # Pad the second dimension to be divisible by block_k_major (block in flash attention) - seq_len = hidden_states.shape[1] - block_k_major = 512 - pad_len = (block_k_major - seq_len % block_k_major) % block_k_major - if pad_len > 0: - hidden_states = F.pad( - hidden_states, (0, 0, 0, pad_len), "constant", 0 - ) - - # Create a mask with ones for the original sequence length and zeros for the padded indexes - mask = torch.ones( - (hidden_states.shape[0], seq_len), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - if pad_len > 0: - mask = F.pad(mask, (0, pad_len), "constant", 0) - - hidden_states = attention( - hidden_states, - attention_mask=( - None if not attention.use_tpu_flash_attention else mask - ), - ) - - if attention.use_tpu_flash_attention: - # Remove the padding - if pad_len > 0: - hidden_states = hidden_states[:, :-pad_len, :] - - # Reshape the hidden states 
back to (batch_size, channel, frames, height, width, channel) - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, frames, height, width - ) - else: - for resnet in self.res_blocks: - hidden_states = resnet( - hidden_states, causal=causal, timestep=timestep_embed - ) - - return hidden_states + return hidden_states class SpaceToDepthDownsample(nn.Module): - def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): - super().__init__() - self.stride = stride - self.group_size = in_channels * np.prod(stride) // out_channels - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels // np.prod(stride), - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - def forward(self, x, causal: bool = True): - if self.stride[0] == 2: - x = torch.cat( - [x[:, :, :1, :, :], x], dim=2 - ) # duplicate first frames for padding - - # skip connection - x_in = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) - x_in = x_in.mean(dim=2) - - # conv - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - x = x + x_in + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in - return x + return x class DepthToSpaceUpsample(nn.Module): - def __init__( - self, - dims, - in_channels, - stride, - residual=False, - out_channels_reduction_factor=1, - spatial_padding_mode="zeros", - ): - super().__init__() - self.stride = stride - self.out_channels = ( - np.prod(stride) * in_channels // out_channels_reduction_factor - ) - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) - self.residual = residual - self.out_channels_reduction_factor = out_channels_reduction_factor - - def forward(self, x, causal: bool = True): - if self.residual: - # Reshape and duplicate the input to match the output shape - x_in = self.pixel_shuffle(x) - num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor - x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: - x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = 
self.pixel_shuffle(x) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = np.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x -class LayerNorm(nn.Module): - def __init__(self, dim, eps, elementwise_affine=True) -> None: - super().__init__() - self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward(self, x): - x = rearrange(x, "b c d h w -> b d h w c") - x = self.norm(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x +class LayerNorm(nn.Module): -class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
- """ + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.inject_noise = inject_noise - - if norm_layer == "group_norm": - self.norm1 = nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x - self.non_linearity = nn.SiLU() - self.conv1 = make_conv_nd( - dims, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) - if inject_noise: - self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) - if norm_layer == "group_norm": - self.norm2 = nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + self.conv_shortcut = ( + make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels + else nn.Identity() + ) - self.dropout = torch.nn.Dropout(dropout) + self.norm3 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) if in_channels != out_channels else nn.Identity() - self.conv2 = make_conv_nd( - dims, - out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) + self.timestep_conditioning = timestep_conditioning - if inject_noise: - self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) + def _feed_spatial_noise(self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype - self.norm3 = ( - LayerNorm(in_channels, eps=eps, elementwise_affine=True) - if in_channels != out_channels - else nn.Identity() - ) + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = 
torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise - self.timestep_conditioning = timestep_conditioning + return hidden_states - if timestep_conditioning: - self.scale_shift_table = nn.Parameter( - torch.randn(4, in_channels) / in_channels**0.5 - ) + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] - def _feed_spatial_noise( - self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor - ) -> torch.FloatTensor: - spatial_shape = hidden_states.shape[-2:] - device = hidden_states.device - dtype = hidden_states.dtype - - # similar to the "explicit noise inputs" method in style-gan - spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] - scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] - hidden_states = hidden_states + scaled_noise - - return hidden_states - - def forward( - self, - input_tensor: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - hidden_states = input_tensor - batch_size = hidden_states.shape[0] - - hidden_states = self.norm1(hidden_states).to(torch.bfloat16) - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - ada_values = self.scale_shift_table[ - None, ..., None, None, None - ] + timestep.reshape( - batch_size, - 4, - -1, - timestep.shape[-3], - timestep.shape[-2], - timestep.shape[-1], - ) - shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + hidden_states = self.norm1(hidden_states).to(torch.bfloat16) + if self.timestep_conditioning: + assert timestep is not None, "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[None, ..., None, None, None] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) - hidden_states = hidden_states * (1 + scale1) + shift1 + hidden_states = hidden_states * (1 + scale1) + shift1 - hidden_states = self.non_linearity(hidden_states) + hidden_states = self.non_linearity(hidden_states) - hidden_states = self.conv1(hidden_states, causal=causal) + hidden_states = self.conv1(hidden_states, causal=causal) - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 - ) + if self.inject_noise: + hidden_states = self._feed_spatial_noise(hidden_states, self.per_channel_scale1) - hidden_states = self.norm2(hidden_states).to(torch.bfloat16) + hidden_states = self.norm2(hidden_states).to(torch.bfloat16) - if self.timestep_conditioning: - hidden_states = hidden_states * (1 + scale2) + shift2 + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 - hidden_states = self.non_linearity(hidden_states) + hidden_states = self.non_linearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, causal=causal) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 - ) + if self.inject_noise: + hidden_states = 
self._feed_spatial_noise(hidden_states, self.per_channel_scale2) - input_tensor = self.norm3(input_tensor).to(torch.bfloat16) + input_tensor = self.norm3(input_tensor).to(torch.bfloat16) - batch_size = input_tensor.shape[0] + batch_size = input_tensor.shape[0] - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + output_tensor = input_tensor + hidden_states - return output_tensor + return output_tensor def patchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - + if patch_size_hw == 1 and patch_size_t == 1: return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x def unpatchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x + if patch_size_hw == 1 and patch_size_t == 1: + return x - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) - return x + return x def create_video_autoencoder_demo_config( latent_channels: int = 64, ): - encoder_blocks = [ - ("res_x", {"num_layers": 2}), - ("compress_space_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ("compress_time_res", {"multiplier": 2}), - ("res_x", {"num_layers": 1}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 1}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 1}), - ] - decoder_blocks = [ - ("res_x", {"num_layers": 2, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 2, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 2, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 2, "inject_noise": False}), - ] - return { - "_class_name": "CausalVideoAutoencoder", - "dims": 3, - "encoder_blocks": encoder_blocks, - "decoder_blocks": decoder_blocks, - "latent_channels": latent_channels, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, - "timestep_conditioning": True, - "spatial_padding_mode": "replicate", - } + encoder_blocks = [ + ("res_x", {"num_layers": 2}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), 
+ ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ] + return { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "spatial_padding_mode": "replicate", + } def test_vae_patchify_unpatchify(): - import torch + import torch - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_demo_config() + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_demo_config() - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = CausalVideoAutoencoder.from_config(config) + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) - print(video_autoencoder) - video_autoencoder.eval() - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 17, 64, 64) + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") - timestep = torch.ones(input_videos.shape[0]) * 0.1 - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape, timestep=timestep - ).sample + timestep = torch.ones(input_videos.shape[0]) * 0.1 + 
reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape, timestep=timestep).sample - print(f"reconstructed shape={reconstructed_videos.shape}") + print(f"reconstructed shape={reconstructed_videos.shape}") - # Validate that single image gets treated the same way as first frame - input_image = input_videos[:, :, :1, :, :] - image_latent = video_autoencoder.encode(input_image).latent_dist.mode() - _ = video_autoencoder.decode( - image_latent, target_shape=image_latent.shape, timestep=timestep - ).sample + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode(image_latent, target_shape=image_latent.shape, timestep=timestep).sample - first_frame_latent = latent[:, :, :1, :, :] + first_frame_latent = latent[:, :, :1, :, :] - assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) - # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - # Perform backward pass - loss.backward() + # Perform backward pass + loss.backward() - print(f"Demo completed with loss: {loss.item()}") + print(f"Demo completed with loss: {loss.item()}") # Ensure to call the demo function to execute the forward and backward pass if __name__ == "__main__": - demo_video_autoencoder_forward_backward() + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py index 718c69bef..dd4aba0ca 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py @@ -2,8 +2,8 @@ import torch -from ltx_video.models.autoencoders.dual_conv3d import DualConv3d -from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d +from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d +from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d def make_conv_nd( @@ -20,56 +20,56 @@ def make_conv_nd( spatial_padding_mode="zeros", temporal_padding_mode="zeros", ): - if not (spatial_padding_mode == temporal_padding_mode or causal): - raise NotImplementedError("spatial and temporal padding modes must be equal") - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == 3: - if causal: - return CausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - 
groups=groups, - bias=bias, - spatial_padding_mode=spatial_padding_mode, - ) - return torch.nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == (2, 1): - return DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=bias, - padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") def make_linear_nd( @@ -78,13 +78,9 @@ def make_linear_nd( out_channels: int, bias=True, ): - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - elif dims == 3 or dims == (2, 1): - return torch.nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") + if dims == 2: + return torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py index dcf889296..1487a539d 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -8,210 +8,201 @@ class DualConv3d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - dilation: Union[int, Tuple[int, int, int]] = 1, - groups=1, - bias=True, - padding_mode="zeros", - ): - super(DualConv3d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.padding_mode = padding_mode - # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size, kernel_size) - if kernel_size == (1, 1, 1): - raise ValueError( - "kernel_size must be greater than 
1. Use make_linear_nd instead." - ) - if isinstance(stride, int): - stride = (stride, stride, stride) - if isinstance(padding, int): - padding = (padding, padding, padding) - if isinstance(dilation, int): - dilation = (dilation, dilation, dilation) - - # Set parameters for convolutions - self.groups = groups - self.bias = bias - - # Define the size of the channels after the first convolution - intermediate_channels = ( - out_channels if in_channels < out_channels else in_channels - ) - # Define parameters for the first convolution - self.weight1 = nn.Parameter( - torch.Tensor( - intermediate_channels, - in_channels // groups, - 1, - kernel_size[1], - kernel_size[2], - ) - ) - self.stride1 = (1, stride[1], stride[2]) - self.padding1 = (0, padding[1], padding[2]) - self.dilation1 = (1, dilation[1], dilation[2]) - if bias: - self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) - else: - self.register_parameter("bias1", None) - - # Define parameters for the second convolution - self.weight2 = nn.Parameter( - torch.Tensor( - out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 - ) - ) - self.stride2 = (stride[0], 1, 1) - self.padding2 = (padding[0], 0, 0) - self.dilation2 = (dilation[0], 1, 1) - if bias: - self.bias2 = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter("bias2", None) - - # Initialize weights and biases - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) - nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) - if self.bias: - fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) - bound1 = 1 / math.sqrt(fan_in1) - nn.init.uniform_(self.bias1, -bound1, bound1) - fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) - bound2 = 1 / math.sqrt(fan_in2) - nn.init.uniform_(self.bias2, -bound2, bound2) - - def forward(self, x, use_conv3d=False, skip_time_conv=False): - if use_conv3d: - return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) - else: - return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) - - def forward_with_3d(self, x, skip_time_conv): - # First convolution - x = F.conv3d( - x, - self.weight1, - self.bias1, - self.stride1, - self.padding1, - self.dilation1, - self.groups, - padding_mode=self.padding_mode, + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. 
Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) - if skip_time_conv: - return x - - # Second convolution - x = F.conv3d( - x, - self.weight2, - self.bias2, - self.stride2, - self.padding2, - self.dilation2, - self.groups, - padding_mode=self.padding_mode, - ) + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) - return x - - def forward_with_2d(self, x, skip_time_conv): - b, c, d, h, w = x.shape - - # First 2D convolution - x = rearrange(x, "b c d h w -> (b d) c h w") - # Squeeze the depth dimension out of weight1 since it's 1 - weight1 = self.weight1.squeeze(2) - # Select stride, padding, and dilation for the 2D convolution - stride1 = (self.stride1[1], self.stride1[2]) - padding1 = (self.padding1[1], self.padding1[2]) - dilation1 = (self.dilation1[1], self.dilation1[2]) - x = F.conv2d( - x, - weight1, - self.bias1, - stride1, - padding1, - dilation1, - self.groups, - padding_mode=self.padding_mode, - ) + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + 
weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) - _, _, h, w = x.shape - - if skip_time_conv: - x = rearrange(x, "(b d) c h w -> b c d h w", b=b) - return x - - # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) - - # Reshape weight2 to match the expected dimensions for conv1d - weight2 = self.weight2.squeeze(-1).squeeze(-1) - # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution - stride2 = self.stride2[0] - padding2 = self.padding2[0] - dilation2 = self.dilation2[0] - x = F.conv1d( - x, - weight2, - self.bias2, - stride2, - padding2, - dilation2, - self.groups, - padding_mode=self.padding_mode, - ) - x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) - return x + return x - @property - def weight(self): - return self.weight2 + @property + def weight(self): + return self.weight2 def test_dual_conv3d_consistency(): - # Initialize parameters - in_channels = 3 - out_channels = 5 - kernel_size = (3, 3, 3) - stride = (2, 2, 2) - padding = (1, 1, 1) - - # Create an instance of the DualConv3d class - dual_conv3d = DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=True, - ) - - # Example input tensor - test_input = torch.randn(1, 3, 10, 10, 10) - - # Perform forward passes with both 3D and 2D settings - output_conv3d = dual_conv3d(test_input, use_conv3d=True) - output_2d = dual_conv3d(test_input, use_conv3d=False) - - # Assert that the outputs from both methods are sufficiently close - assert torch.allclose( - output_conv3d, output_2d, atol=1e-6 - ), "Outputs are not consistent between 3D and 2D convolutions." 
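+  # Why the two code paths should agree (a sketch based on the weight shapes above, not a
+  # formal proof): weight1 has a singleton temporal extent and weight2 has singleton spatial
+  # extents, so the (1, kh, kw) conv3d acts as an independent conv2d per frame and the
+  # (kt, 1, 1) conv3d acts as an independent conv1d per spatial location, matching the
+  # rearranges in forward_with_2d. The assertions below verify this equivalence numerically.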
+ # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose(output_conv3d, output_2d, atol=1e-6), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py index 4a76bc21d..3133b8a49 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py @@ -9,195 +9,186 @@ from diffusers import ConfigMixin, ModelMixin from safetensors.torch import safe_open -from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND class ResBlock(nn.Module): - def __init__( - self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 - ): - super().__init__() - if mid_channels is None: - mid_channels = channels - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) - self.norm1 = nn.GroupNorm(32, mid_channels) - self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) - self.norm2 = nn.GroupNorm(32, channels) - self.activation = nn.SiLU() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = self.activation(x) - x = self.conv2(x) - x = self.norm2(x) - x = self.activation(x + residual) - return x + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x class LatentUpsampler(ModelMixin, ConfigMixin): - """ - Model to spatially upsample VAE latents. 
- - Args: - in_channels (`int`): Number of channels in the input latent - mid_channels (`int`): Number of channels in the middle layers - num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) - dims (`int`): Number of dimensions for convolutions (2 or 3) - spatial_upsample (`bool`): Whether to spatially upsample the latent - temporal_upsample (`bool`): Whether to temporally upsample the latent - """ - - def __init__( - self, - in_channels: int = 128, - mid_channels: int = 512, - num_blocks_per_stage: int = 4, - dims: int = 3, - spatial_upsample: bool = True, - temporal_upsample: bool = False, - ): - super().__init__() - - self.in_channels = in_channels - self.mid_channels = mid_channels - self.num_blocks_per_stage = num_blocks_per_stage - self.dims = dims - self.spatial_upsample = spatial_upsample - self.temporal_upsample = temporal_upsample - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) - self.initial_norm = nn.GroupNorm(32, mid_channels) - self.initial_activation = nn.SiLU() - - self.res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - if spatial_upsample and temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(3), - ) - elif spatial_upsample: - self.upsampler = nn.Sequential( - nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(2), - ) - elif temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(1), - ) - else: - raise ValueError( - "Either spatial_upsample or temporal_upsample must be True" - ) - - self.post_upsample_res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, latent: torch.Tensor) -> torch.Tensor: - b, c, f, h, w = latent.shape - - if self.dims == 2: - x = rearrange(latent, "b c f h w -> (b f) c h w") - x = self.initial_conv(x) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - x = self.upsampler(x) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - else: - x = self.initial_conv(latent) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - if self.temporal_upsample: - x = self.upsampler(x) - x = x[:, :, 1:, :, :] - else: - x = rearrange(x, "b c f h w -> (b f) c h w") - x = self.upsampler(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - - return x - - @classmethod - def from_config(cls, config): - return cls( - in_channels=config.get("in_channels", 4), - mid_channels=config.get("mid_channels", 128), - num_blocks_per_stage=config.get("num_blocks_per_stage", 4), - dims=config.get("dims", 2), - spatial_upsample=config.get("spatial_upsample", True), - temporal_upsample=config.get("temporal_upsample", False), - ) - - def config(self): - return { - "_class_name": "LatentUpsampler", - "in_channels": self.in_channels, - "mid_channels": self.mid_channels, - "num_blocks_per_stage": self.num_blocks_per_stage, - "dims": self.dims, - 
"spatial_upsample": self.spatial_upsample, - "temporal_upsample": self.temporal_upsample, - } - - @classmethod - def from_pretrained( - cls, - pretrained_model_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_path = Path(pretrained_model_path) - if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( - ".safetensors" - ): - state_dict = {} - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - config = json.loads(metadata["config"]) - with torch.device("meta"): - latent_upsampler = LatentUpsampler.from_config(config) - latent_upsampler.load_state_dict(state_dict, assign=True) - return latent_upsampler + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, 
"b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(".safetensors"): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler if __name__ == "__main__": - latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) - print(latent_upsampler) - total_params = sum(p.numel() for p in latent_upsampler.parameters()) - print(f"Total number of parameters: {total_params:,}") - latent = torch.randn(1, 128, 9, 16, 16) - upsampled_latent = latent_upsampler(latent) - print(f"Upsampled latent shape: {upsampled_latent.shape}") + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py index 9bc3ea60e..0d4277e06 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -3,10 +3,11 @@ class PixelNorm(nn.Module): - def __init__(self, dim=1, eps=1e-8): - super(PixelNorm, self).__init__() - self.dim = dim - self.eps = eps - def forward(self, x): - return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py index 4e79ae284..dae539d26 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -3,31 +3,32 @@ class PixelShuffleND(nn.Module): - def __init__(self, dims, upscale_factors=(2, 2, 2)): - 
super().__init__() - assert dims in [1, 2, 3], "dims must be 1, 2, or 3" - self.dims = dims - self.upscale_factors = upscale_factors - def forward(self, x): - if self.dims == 3: - return rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - p3=self.upscale_factors[2], - ) - elif self.dims == 2: - return rearrange( - x, - "b (c p1 p2) h w -> b c (h p1) (w p2)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - ) - elif self.dims == 1: - return rearrange( - x, - "b (c p1) f h w -> b c (f p1) h w", - p1=self.upscale_factors[0], - ) + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py index 5b22217c1..1a05b675c 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -10,371 +10,335 @@ DiagonalGaussianDistribution, ) from diffusers.models.modeling_outputs import AutoencoderKLOutput -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd class AutoencoderKLWrapper(ModelMixin, ConfigMixin): - """Variational Autoencoder (VAE) model with KL loss. + """Variational Autoencoder (VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. 
+ """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + normalize_latent_channels: bool = False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + self.normalize_latent_channels = normalize_latent_channels + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd(quant_dims, 2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = make_conv_nd(quant_dims, latent_channels, latent_channels, 1) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + + if normalize_latent_channels: + if dims == 2: + self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert z_sample_size % 8 == 0 or z_sample_size == 1, f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False - VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. - This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (1 - z / blend_extent) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
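+    # The seam removal here reuses the linear cross-fade from the encode path. As a rough
+    # sketch, for a blend region of width E, position k mixes the previous tile `a` and the
+    # current tile `b` as
+    #     b[k] = a[-E + k] * (1 - k / E) + b[k] * (k / E)
+    # which is what blend_z / blend_v / blend_h implement along depth, height and width.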
+ rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = sizes + [z.shape[2] - sum(sizes)] if z.shape[2] - sum(sizes) > 0 else sizes + tiles = z.split(sizes, dim=2) + moments_tiles = [ + (self._hw_tiled_encode(z_tile, return_dict) if self.use_hw_tiling else self._encode(z_tile)) for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = self._hw_tiled_encode(z, return_dict) if self.use_hw_tiling else self._encode(z) + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + _, c, _, _, _ = z.shape + z = torch.cat( + [ + self.latent_norm_out(z[:, : c // 2, :, :, :]), + z[:, c // 2 :, :, :, :], + ], + dim=1, + ) + elif isinstance(self.latent_norm_out, nn.BatchNorm2d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) + running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) + eps = self.latent_norm_out.eps + + z = z * torch.sqrt(running_var + eps) + running_mean + elif isinstance(self.latent_norm_out, nn.BatchNorm3d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + moments = self._normalize_latent_channels(moments) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self._unnormalize_latent_channels(z) + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + 
reduction_factor = int( + self.encoder.patch_size_t * 2 ** (len(self.encoder.down_blocks) - 1 - math.sqrt(self.encoder.patch_size)) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode( + z, target_shape + ) # z(1, 128, 1, 6, 8), need to be torch.Size([]) target (1, 3, 8, 192, 256), should be type Tensor + if self.use_hw_tiling + else self._decode( + z, target_shape=target_shape, timestep=timestep + ) # Tensor( 0.05) size torch.Size([]) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" Args: - encoder (`nn.Module`): - Encoder module. - decoder (`nn.Module`): - Decoder module. - latent_channels (`int`, *optional*, defaults to 4): - Number of latent channels. + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. """ - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - latent_channels: int = 4, - dims: int = 2, - sample_size=512, - use_quant_conv: bool = True, - normalize_latent_channels: bool = False, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = encoder - self.use_quant_conv = use_quant_conv - self.normalize_latent_channels = normalize_latent_channels - - # pass init params to Decoder - quant_dims = 2 if dims == 2 else 3 - self.decoder = decoder - if use_quant_conv: - self.quant_conv = make_conv_nd( - quant_dims, 2 * latent_channels, 2 * latent_channels, 1 - ) - self.post_quant_conv = make_conv_nd( - quant_dims, latent_channels, latent_channels, 1 - ) - else: - self.quant_conv = nn.Identity() - self.post_quant_conv = nn.Identity() - - if normalize_latent_channels: - if dims == 2: - self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.Identity() - self.use_z_tiling = False - self.use_hw_tiling = False - self.dims = dims - self.z_sample_size = 1 - - self.decoder_params = inspect.signature(self.decoder.forward).parameters - - # only relevant if vae tiling is enabled - self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) - - def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): - self.tile_sample_min_size = sample_size - num_blocks = len(self.encoder.down_blocks) - self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) - self.tile_overlap_factor = overlap_factor - - def enable_z_tiling(self, z_sample_size: int = 8): - r""" - Enable tiling during VAE decoding. 
- - When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_z_tiling = z_sample_size > 1 - self.z_sample_size = z_sample_size - assert ( - z_sample_size % 8 == 0 or z_sample_size == 1 - ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." - - def disable_z_tiling(self): - r""" - Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_z_tiling = False - - def enable_hw_tiling(self): - r""" - Enable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = True - - def disable_hw_tiling(self): - r""" - Disable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = False - - def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent - - # Split the image into 512x512 tiles and encode them separately. - rows = [] - for i in range(0, x.shape[3], overlap_size): - row = [] - for j in range(0, x.shape[4], overlap_size): - tile = x[ - :, - :, - :, - i : i + self.tile_sample_min_size, - j : j + self.tile_sample_min_size, - ] - tile = self.encoder(tile) - tile = self.quant_conv(tile) - row.append(tile) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - moments = torch.cat(result_rows, dim=3) - return moments - - def blend_z( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[2], b.shape[2], blend_extent) - for z in range(blend_extent): - b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( - 1 - z / blend_extent - ) + b[:, :, z, :, :] * (z / blend_extent) - return b - - def blend_v( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[3], b.shape[3], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( - 1 - y / blend_extent - ) + b[:, :, :, y, :] * (y / blend_extent) - return b - - def blend_h( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[4], b.shape[4], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( - 1 - x / blend_extent - ) + b[:, :, :, :, x] * (x / blend_extent) - return b - - def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent - tile_target_shape = ( - *target_shape[:3], - self.tile_sample_min_size, - self.tile_sample_min_size, - ) - # Split z into overlapping 64x64 tiles and decode them separately. 
- # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, z.shape[3], overlap_size): - row = [] - for j in range(0, z.shape[4], overlap_size): - tile = z[ - :, - :, - :, - i : i + self.tile_latent_min_size, - j : j + self.tile_latent_min_size, - ] - tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, target_shape=tile_target_shape) - row.append(decoded) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - return dec - - def encode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: - num_splits = z.shape[2] // self.z_sample_size - sizes = [self.z_sample_size] * num_splits - sizes = ( - sizes + [z.shape[2] - sum(sizes)] - if z.shape[2] - sum(sizes) > 0 - else sizes - ) - tiles = z.split(sizes, dim=2) - moments_tiles = [ - ( - self._hw_tiled_encode(z_tile, return_dict) - if self.use_hw_tiling - else self._encode(z_tile) - ) - for z_tile in tiles - ] - moments = torch.cat(moments_tiles, dim=2) - - else: - moments = ( - self._hw_tiled_encode(z, return_dict) - if self.use_hw_tiling - else self._encode(z) - ) - - posterior = DiagonalGaussianDistribution(moments) - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - _, c, _, _, _ = z.shape - z = torch.cat( - [ - self.latent_norm_out(z[:, : c // 2, :, :, :]), - z[:, c // 2 :, :, :, :], - ], - dim=1, - ) - elif isinstance(self.latent_norm_out, nn.BatchNorm2d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) - running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) - eps = self.latent_norm_out.eps - - z = z * torch.sqrt(running_var + eps) + running_mean - elif isinstance(self.latent_norm_out, nn.BatchNorm3d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - moments = self._normalize_latent_channels(moments) - return moments - - def _decode( - self, - z: torch.FloatTensor, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - z = self._unnormalize_latent_channels(z) - z = self.post_quant_conv(z) - if "timestep" in self.decoder_params: - dec = self.decoder(z, target_shape=target_shape, timestep=timestep) - else: - dec = self.decoder(z, target_shape=target_shape) - return dec - - def decode( - self, - z: torch.FloatTensor, - return_dict: bool = True, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - assert target_shape is not None, "target_shape must 
be provided for decoding" - if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: - reduction_factor = int( - self.encoder.patch_size_t - * 2 - ** ( - len(self.encoder.down_blocks) - - 1 - - math.sqrt(self.encoder.patch_size) - ) - ) - split_size = self.z_sample_size // reduction_factor - num_splits = z.shape[2] // split_size - - # copy target shape, and divide frame dimension (=2) by the context size - target_shape_split = list(target_shape) - target_shape_split[2] = target_shape[2] // num_splits - - decoded_tiles = [ - ( - self._hw_tiled_decode(z_tile, target_shape_split) - if self.use_hw_tiling - else self._decode(z_tile, target_shape=target_shape_split) - ) - for z_tile in torch.tensor_split(z, num_splits, dim=2) - ] - decoded = torch.cat(decoded_tiles, dim=2) - else: - decoded = ( - self._hw_tiled_decode(z, target_shape) - if self.use_hw_tiling - else self._decode(z, target_shape=target_shape, timestep=timestep) - ) - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`DecoderOutput`] instead of a plain tuple. - generator (`torch.Generator`, *optional*): - Generator used to sample from the posterior. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z, target_shape=sample.shape).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py index 5a0aeeccf..e24f16fd2 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py @@ -4,7 +4,6 @@ from einops import rearrange from torch import Tensor - from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) @@ -14,9 +13,9 @@ ) try: - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm except ImportError: - xm = None + xm = None def vae_encode( @@ -25,73 +24,68 @@ def vae_encode( split_size: int = 1, vae_per_channel_normalize=False, ) -> Tensor: - """ - Encodes media items (images or videos) into latent representations using a specified VAE model. - The function supports processing batches of images or video frames and can handle the processing - in smaller sub-batches if needed. - - Args: - media_items (Tensor): A torch Tensor containing the media items to encode. The expected - shape is (batch_size, channels, height, width) for images or (batch_size, channels, - frames, height, width) for videos. 
- vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, - pre-configured and loaded with the appropriate model weights. - split_size (int, optional): The number of sub-batches to split the input batch into for encoding. - If set to more than 1, the input media items are processed in smaller batches according to - this value. Defaults to 1, which processes all items in a single batch. - - Returns: - Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted - to match the input shape, scaled by the model's configuration. - - Examples: - >>> import torch - >>> from diffusers import AutoencoderKL - >>> vae = AutoencoderKL.from_pretrained('your-model-name') - >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. - >>> latents = vae_encode(images, vae) - >>> print(latents.shape) # Output shape will depend on the model's latent configuration. - - Note: - In case of a video, the function encodes the media item frame-by frame. - """ - is_video_shaped = media_items.dim() == 5 - batch_size, channels = media_items.shape[0:2] - - if channels != 3: - raise ValueError(f"Expects tensors with 3 channels, got {channels}.") - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - media_items = rearrange(media_items, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(media_items) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(media_items) // split_size - # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] - latents = [] - if media_items.device.type == "xla": - xm.mark_step() - for image_batch in media_items.split(encode_bs): - latents.append(vae.encode(image_batch).latent_dist.sample()) - if media_items.device.type == "xla": - xm.mark_step() - latents = torch.cat(latents, dim=0) - else: - latents = vae.encode(media_items).latent_dist.sample() - - latents = normalize_latents(latents, vae, vae_per_channel_normalize) - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) - return latents - - -def vae_decode( + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. Defaults to 1, which processes all items in a single batch. + + Returns: + Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted + to match the input shape, scaled by the model's configuration. 
+ + Examples: + >>> import torch + >>> from diffusers import AutoencoderKL + >>> vae = AutoencoderKL.from_pretrained('your-model-name') + >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. + >>> latents = vae_encode(images, vae) + >>> print(latents.shape) # Output shape will depend on the model's latent configuration. + + Note: + In case of a video, the function encodes the media item frame-by frame. + """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split") + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + dist = vae.encode(media_items).latent_dist + latents = dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( # this function needs latents to be in tensor form latents: Tensor, vae: AutoencoderKL, is_video: bool = True, @@ -99,36 +93,26 @@ def vae_decode( vae_per_channel_normalize=False, timestep=None, ) -> Tensor: - is_video_shaped = latents.dim() == 5 - batch_size = latents.shape[0] - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(latents) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(latents) // split_size - image_batch = [ - _run_decoder( - latent_batch, vae, is_video, vae_per_channel_normalize, timestep - ) - for latent_batch in latents.split(encode_bs) - ] - images = torch.cat(image_batch, dim=0) - else: - images = _run_decoder( - latents, vae, is_video, vae_per_channel_normalize, timestep - ) - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) - return images + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split") + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize, timestep) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize, timestep) + + if is_video_shaped and not 
isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images def _run_decoder( @@ -138,110 +122,88 @@ def _run_decoder( vae_per_channel_normalize=False, timestep=None, ) -> Tensor: - if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): - *_, fl, hl, wl = latents.shape - temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) - latents = latents.to(vae.dtype) - vae_decode_kwargs = {} - if timestep is not None: - vae_decode_kwargs["timestep"] = timestep - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - return_dict=False, - target_shape=( - 1, - 3, - fl * temporal_scale if is_video else 1, - hl * spatial_scale, - wl * spatial_scale, - ), - **vae_decode_kwargs, - )[0] - else: - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - return_dict=False, - )[0] - return image + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: - if isinstance(vae, CausalVideoAutoencoder): - spatial = vae.spatial_downscale_factor - temporal = vae.temporal_downscale_factor - else: - down_blocks = len( - [ - block - for block in vae.encoder.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - spatial = vae.config.patch_size * 2**down_blocks - temporal = ( - vae.config.patch_size_t * 2**down_blocks - if isinstance(vae, VideoAutoencoder) - else 1 - ) - - return (temporal, spatial, spatial) - - -def latent_to_pixel_coords( - latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False -) -> Tensor: - """ - Converts latent coordinates to pixel coordinates by scaling them according to the VAE's - configuration. - - Args: - latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] - containing the latent corner coordinates of each token. - vae (AutoencoderKL): The VAE model - causal_fix (bool): Whether to take into account the different temporal scale - of the first frame. Default = False for backwards compatibility. - Returns: - Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
- """ - - scale_factors = get_vae_size_scale_factor(vae) - causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix - pixel_coords = latent_to_pixel_coords_from_factors( - latent_coords, scale_factors, causal_fix - ) - return pixel_coords - - -def latent_to_pixel_coords_from_factors( - latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False -) -> Tensor: - pixel_coords = ( - latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] - ) - if causal_fix: - # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) - return pixel_coords - - -def normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) - / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents * vae.config.scaling_factor - ) - - -def un_normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents / vae.config.scaling_factor - ) + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)]) + spatial = vae.config.patch_size * 2**down_blocks + temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1 + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords(latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
+ """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors(latent_coords, scale_factors, causal_fix) + return pixel_coords + + +def latent_to_pixel_coords_from_factors(latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False) -> Tensor: + pixel_coords = latent_coords * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py index 45fb33280..dfb7512ff 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py @@ -10,79 +10,88 @@ # remove weight attribute to avoid error in JittableModule # in the future, this will be fixed in ltxv public repo -delattr(causal_conv3d.CausalConv3d, 'weight') +delattr(causal_conv3d.CausalConv3d, "weight") -class TorchaxCausalVideoAutoencoder(interop.JittableModule): - def __init__(self, vae: CausalVideoAutoencoder): - super().__init__(vae, extra_jit_args=dict(static_argnames=['split_size', 'vae_per_channel_normalize'])) - def encode(self, media_items: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True) -> jax.Array: - if media_items.ndim != 5: - raise ValueError( - f"Expected media_items to have 5 dimensions (batch, channels, frames, height, width), but got {media_items.ndim} dimensions." - ) - num_frames = media_items.shape[2] - if (num_frames - 1) % 8 != 0: - raise ValueError( - f"Expected media_items to have a number of frames that is 1 + 8 * k for some integer k, but got {num_frames} frames." - ) - with default_env(): - media_items = interop.torch_view(media_items) +class TorchaxCausalVideoAutoencoder(interop.JittableModule): - output = self.functional_call( - self._vae_encoder_inner, - params=self.params, - buffers=self.buffers, - media_items=media_items, - split_size=split_size, - vae_per_channel_normalize=vae_per_channel_normalize, - ) + def __init__(self, vae: CausalVideoAutoencoder): + super().__init__(vae, extra_jit_args=dict(static_argnames=["split_size", "vae_per_channel_normalize"])) + def encode(self, media_items: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True) -> jax.Array: + if media_items.ndim != 5: + raise ValueError( + f"Expected media_items to have 5 dimensions (batch, channels, frames, height, width), but got {media_items.ndim} dimensions." 
+ ) + num_frames = media_items.shape[2] + if (num_frames - 1) % 8 != 0: + raise ValueError( + f"Expected media_items to have a number of frames that is 1 + 8 * k for some integer k, but got {num_frames} frames." + ) + with default_env(): + media_items = interop.torch_view(media_items) - return interop.jax_view(output) + output = self.functional_call( + self._vae_encoder_inner, + params=self.params, + buffers=self.buffers, + media_items=media_items, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) - def decode(self, latents: jax.Array, timestep: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True, is_video: bool = True) -> jax.Array: - with default_env(): - latents = interop.torch_view(latents) - timestep = interop.torch_view(timestep) - output = self.functional_call( - self._vae_decoder_inner, - params=self.params, - buffers=self.buffers, - latents=latents, - timestep=timestep, - split_size=split_size, - vae_per_channel_normalize=vae_per_channel_normalize, - is_video=is_video, - ) + return interop.jax_view(output) - return interop.jax_view(output) + def decode( + self, + latents: jax.Array, + timestep: jax.Array, + split_size: int = 1, + vae_per_channel_normalize: bool = True, + is_video: bool = True, + ) -> jax.Array: + with default_env(): + latents = interop.torch_view(latents) + timestep = interop.torch_view(timestep) + output = self.functional_call( + self._vae_decoder_inner, + params=self.params, + buffers=self.buffers, + latents=latents, + timestep=timestep, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + is_video=is_video, + ) + return interop.jax_view(output) - @staticmethod - def _vae_encoder_inner(model, media_items, split_size, vae_per_channel_normalize): - return vae_encode( - media_items=media_items, - vae=model, - split_size=split_size, - vae_per_channel_normalize=vae_per_channel_normalize, - ) + @staticmethod + def _vae_encoder_inner(model, media_items, split_size, vae_per_channel_normalize): + return vae_encode( + media_items=media_items, + vae=model, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + ) - @staticmethod - def _vae_decoder_inner(model, latents, timestep, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize: bool = False): - return vae_decode( - latents=latents, - vae=model, - is_video=is_video, - split_size=split_size, - vae_per_channel_normalize=vae_per_channel_normalize, - timestep=timestep, - ) + @staticmethod + def _vae_decoder_inner( + model, latents, timestep, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize: bool = False + ): + return vae_decode( + latents=latents, + vae=model, + is_video=is_video, + split_size=split_size, + vae_per_channel_normalize=vae_per_channel_normalize, + timestep=timestep, + ) - @staticmethod - def normalize_img(image): - return (image - 128) / 128 + @staticmethod + def normalize_img(image): + return (image - 128) / 128 - @staticmethod - def denormalize_img(image): - return (image * 128 + 128).clip(0, 255) \ No newline at end of file + @staticmethod + def denormalize_img(image): + return (image * 128 + 128).clip(0, 255) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py index 3c7926c1d..77ae60d4e 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -12,1034 +12,953 @@ from 
diffusers.utils import logging from ltx_video.utils.torch_utils import Identity -from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from ltx_video.models.autoencoders.pixel_norm import PixelNorm -from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper logger = logging.get_logger(__name__) class VideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - video_vae = cls.from_config(config) - video_vae.to(kwargs["torch_dtype"]) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - ckpt_state_dict = torch.load(model_local_path) - video_vae.load_state_dict(ckpt_state_dict) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) - video_vae.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) - - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "VideoAutoencoder" - ), "config must have _class_name=VideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - - if use_quant_conv and latent_log_var == "uniform": - raise ValueError("uniform latent_log_var requires use_quant_conv=False") - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - dims = config["dims"] - return VideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="VideoAutoencoder", - dims=self.dims, - 
in_channels=self.encoder.conv_in.in_channels - // (self.encoder.patch_size_t * self.encoder.patch_size**2), - out_channels=self.decoder.conv_out.out_channels - // (self.decoder.patch_size_t * self.decoder.patch_size**2), - latent_channels=self.decoder.conv_in.in_channels, - block_out_channels=[ - self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels - for i in range(len(self.encoder.down_blocks)) - ], - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - patch_size_t=self.encoder.patch_size_t, - add_channel_padding=self.encoder.add_channel_padding, - ) - - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def downscale_factor(self): - return self.encoder.downsample_factor - - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - model_keys = set(name for name, _ in self.named_parameters()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - - converted_state_dict = {} - for key, value in state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - if "norm" in key and key not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. 
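# Editor's sketch (not part of the patch): a hedged usage example for the from_pretrained
# classmethod of VideoAutoencoder. It assumes a local checkpoint directory containing
# config.json, autoencoder.pth and, optionally, per_channel_statistics.json, as the method
# expects; the directory path below is a placeholder, and torch_dtype is required by the code.
from pathlib import Path
import torch
from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import VideoAutoencoder

ckpt_dir = Path("/path/to/video_vae_checkpoint")  # hypothetical location
video_vae = VideoAutoencoder.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16)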
+ @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json" + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)} + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none") + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + 
@property + def is_video_supported(self): """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - if add_channel_padding: - in_channels = in_channels * self.patch_size**3 - else: - in_channels = in_channels * self.patch_size_t * self.patch_size**2 - self.in_channels = in_channels - output_channel = block_out_channels[0] - - self.conv_in = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - padding=1, - ) - - self.down_blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels)): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - num_layers=self.layers_per_block, - add_downsample=not is_final_block and 2**i >= patch_size, - resnet_eps=1e-6, - downsample_padding=0, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.down_blocks.append(down_block) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], - num_groups=norm_num_groups, - eps=1e-6, - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, block_out_channels[-1], conv_out_channels, 3, padding=1 - ) - - self.gradient_checkpointing = False - - @property - def downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - * self.patch_size - ) - - def forward( - self, sample: torch.FloatTensor, return_features=False - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - downsample_in_time = sample.shape[2] != 1 - - # patchify - patch_size_t = self.patch_size_t if downsample_in_time else 1 - sample = patchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) - - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and 
self.training - else lambda x: x - ) - - if return_features: - features = [] - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)( - sample, downsample_in_time=downsample_in_time - ) - if return_features: - features.append(sample) - - sample = checkpoint_fn(self.mid_block)(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - - if return_features: - features.append(sample[:, : self.latent_channels, ...]) - return sample, features - return sample - + @property + def downscale_factor(self): + return self.encoder.downsample_factor -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - """ + def to_json_string(self) -> str: + import json - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] 
= (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - if add_channel_padding: - out_channels = out_channels * self.patch_size**3 - else: - out_channels = out_channels * self.patch_size_t * self.patch_size**2 - self.out_channels = out_channels - - self.conv_in = make_conv_nd( - dims, - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - dims=dims, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block - and 2 ** (len(block_out_channels) - i - 1) > patch_size, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.up_blocks.append(up_block) + return json.dumps(self.config.__dict__) - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, block_out_channels[0], out_channels, 3, padding=1 - ) + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } - self.gradient_checkpointing = False + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) - def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - upsample_in_time = sample.shape[2] < target_shape[2] + if "norm" in key and key not in model_keys: + logger.info(f"Removing key {key} from state_dict as it is not present in the model") + continue - sample = self.conv_in(sample) + converted_state_dict[key] = value - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + super().load_state_dict(converted_state_dict, strict=strict) - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer - sample = 
checkpoint_fn(self.mid_block)(sample) - sample = sample.to(upscale_dtype) - for up_block in self.up_blocks: - sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + 
conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return 2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)]) * self.patch_size + + def forward(self, sample: torch.FloatTensor, return_features=False) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - # un-patchify - patch_size_t = self.patch_size_t if upsample_in_time else 1 - sample = unpatchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels - return sample + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample class DownEncoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_downsample: bool = True, - downsample_padding: int = 1, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] 
- - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - if add_downsample: - self.downsample = Downsample3D( - dims, - out_channels, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsample = Identity() + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() - def forward( - self, hidden_states: torch.FloatTensor, downsample_in_time - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) + def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) - hidden_states = self.downsample( - hidden_states, downsample_in_time=downsample_in_time - ) + hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time) - return hidden_states + return hidden_states class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
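# Editor's sketch (not part of the patch): the mid block defined in this module is
# shape-preserving, since each ResnetBlock3D inside it maps in_channels -> in_channels.
# Sizes are illustrative, and it is assumed that make_conv_nd(dims=3, ...) builds ordinary
# 3D convolutions with padding 1.
import torch

mid = UNetMidBlock3D(dims=3, in_channels=64, num_layers=2)
x = torch.rand(1, 64, 4, 16, 16)  # (B, C, F, H, W)
assert mid(x).shape == x.shape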
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) - """ + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - return hidden_states + return hidden_states class UpDecoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_upsample: bool = True, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=input_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - if add_upsample: - self.upsample = Upsample3D( - dims=dims, channels=out_channels, out_channels=out_channels - ) - else: - self.upsample = Identity() + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels) + else: + self.upsample = Identity() - self.resolution_idx = resolution_idx + self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.FloatTensor, upsample_in_time=True - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) + def forward(self, 
hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) - hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) - return hidden_states + return hidden_states class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - """ + r""" + A Resnet block. - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - if norm_layer == "group_norm": - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
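# Editor's sketch (not part of the patch): when in_channels != out_channels, ResnetBlock3D
# routes the skip connection through make_linear_nd, so the residual sum below is well defined.
# Sizes are illustrative; make_conv_nd / make_linear_nd are assumed to preserve the
# frame/height/width dimensions for dims=3.
import torch

block = ResnetBlock3D(dims=3, in_channels=32, out_channels=64)
x = torch.rand(1, 32, 4, 8, 8)  # (B, C, F, H, W)
y = block(x)
assert y.shape == (1, 64, 4, 8, 8)  # channels projected by the shortcut; other dims unchanged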
+ """ - self.non_linearity = nn.SiLU() + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut - self.conv1 = make_conv_nd( - dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() - if norm_layer == "group_norm": - self.norm2 = torch.nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() + self.non_linearity = nn.SiLU() - self.dropout = torch.nn.Dropout(dropout) + self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.conv2 = make_conv_nd( - dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) + self.dropout = torch.nn.Dropout(dropout) - def forward( - self, - input_tensor: torch.FloatTensor, - ) -> torch.FloatTensor: - hidden_states = input_tensor + self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1) - hidden_states = self.norm1(hidden_states) + self.conv_shortcut = ( + make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels + else nn.Identity() + ) - hidden_states = self.non_linearity(hidden_states) + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor - hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) - hidden_states = self.norm2(hidden_states) + hidden_states = self.non_linearity(hidden_states) - hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv1(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.norm2(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.non_linearity(hidden_states) - input_tensor = self.conv_shortcut(input_tensor) + hidden_states = self.dropout(hidden_states) - output_tensor = input_tensor + hidden_states + hidden_states = self.conv2(hidden_states) - return output_tensor + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = input_tensor + hidden_states -class Downsample3D(nn.Module): - def __init__( - self, - dims, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - padding: int = 1, - ): - super().__init__() - stride: int = 2 - self.padding = padding - self.in_channels = in_channels - self.dims = dims - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - - def forward(self, x, downsample_in_time=True): - conv = self.conv - if 
self.padding == 0: - if self.dims == 2: - padding = (0, 1, 0, 1) - else: - padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) - - x = functional.pad(x, padding, mode="constant", value=0) - - if self.dims == (2, 1) and not downsample_in_time: - return conv(x, skip_time_conv=True) - - return conv(x) + return output_tensor -class Upsample3D(nn.Module): - """ - An upsampling layer for 3D tensors of shape (B, C, D, H, W). +class Downsample3D(nn.Module): - :param channels: channels in the inputs and outputs. - """ + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) - def __init__(self, dims, channels, out_channels=None): - super().__init__() - self.dims = dims - self.channels = channels - self.out_channels = out_channels or channels - self.conv = make_conv_nd( - dims, channels, out_channels, kernel_size=3, padding=1, bias=True - ) - - def forward(self, x, upsample_in_time): - if self.dims == 2: - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - else: - time_scale_factor = 2 if upsample_in_time else 1 - # print("before:", x.shape) - b, c, d, h, w = x.shape - x = rearrange(x, "b c d h w -> (b d) c h w") - # height and width interpolate - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - _, _, h, w = x.shape - if not upsample_in_time and self.dims == (2, 1): - x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) - return self.conv(x, skip_time_conv=True) +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
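+    :param dims: convolution dimensionality, e.g. (2, 1) for a spatial Conv2d followed by a temporal Conv1d.
+
+    Note: `forward(x, upsample_in_time)` always doubles height and width; the temporal ("D") dimension
+    is doubled only when `upsample_in_time` is True.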
+ """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd(dims, channels, out_channels, kernel_size=3, padding=1, bias=True) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest") + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest") + _, _, h, w = x.shape - # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) - # (b h w) c 1 d - new_d = x.shape[-1] * time_scale_factor - x = functional.interpolate(x, (1, new_d), mode="nearest") - # (b h w) c 1 new_d - x = rearrange( - x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d - ) - # b c d h w + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) - # x = functional.interpolate( - # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - # ) - # print("after:", x.shape) + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d) + # b c d h w - return self.conv(x) + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + return self.conv(x) -def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] - padding_zeros = torch.zeros( - x.shape[0], - channels_to_pad, - x.shape[2], - x.shape[3], - x.shape[4], - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([padding_zeros, x], dim=1) +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + 
x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) - x = x[:, :channels_to_keep, :, :, :] - - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - + if patch_size_hw == 1 and patch_size_t == 1: return x + if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + def create_video_autoencoder_config( latent_channels: int = 4, ): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [ - 128, - 256, - 512, - 512, - ], # Number of output channels of each encoder / decoder inner block - "patch_size": 1, - } - - return config + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + "patch_size": 1, + } + + return config def create_video_autoencoder_pathify4x4x4_config( latent_channels: int = 4, ): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "latent_log_var": "uniform", - } - - return config + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + 
"latent_log_var": "uniform", + } + + return config def create_video_autoencoder_pathify4x4_config( latent_channels: int = 4, ): - config = { - "_class_name": "VideoAutoencoder", - "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "norm_layer": "pixel_norm", - } + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } - return config + return config def test_vae_patchify_unpatchify(): - import torch + import torch - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_pathify4x4x4_config() + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = VideoAutoencoder.from_config(config) + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) - print(video_autoencoder) + print(video_autoencoder) - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 8, 64, 64) + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape - ).sample + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + 
print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample - print(f"reconstructed shape={reconstructed_videos.shape}") + print(f"reconstructed shape={reconstructed_videos.shape}") - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - # Perform backward pass - loss.backward() + # Perform backward pass + loss.backward() - print(f"Demo completed with loss: {loss.item()}") + print(f"Demo completed with loss: {loss.item()}") # Ensure to call the demo function to execute the forward and backward pass if __name__ == "__main__": - demo_video_autoencoder_forward_backward() + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 75692b703..a6ac74c71 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -458,7 +458,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py index 2eca32033..83ec62342 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -8,77 +8,75 @@ class Patchifier(ConfigMixin, ABC): - def __init__(self, patch_size: int): - super().__init__() - self._patch_size = (1, patch_size, patch_size) - @abstractmethod - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - raise NotImplementedError("Patchify method not implemented") + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) - @abstractmethod - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - pass + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") - @property - def patch_size(self): - return self._patch_size + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass - def get_latent_coords( - self, latent_num_frames, latent_height, latent_width, batch_size, device - ): - """ - Return a tensor of shape [batch_size, 3, num_patches] containing the - top-left corner latent coordinates of each latent patch. - The tensor is repeated for each batch element. 
- """ - latent_sample_coords = torch.meshgrid( - torch.arange(0, latent_num_frames, self._patch_size[0], device=device), - torch.arange(0, latent_height, self._patch_size[1], device=device), - torch.arange(0, latent_width, self._patch_size[2], device=device), - ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size - ) - return latent_coords + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords(self, latent_num_frames, latent_height, latent_width, batch_size, device): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange(latent_coords, "b c f h w -> b c (f h w)", b=batch_size) + return latent_coords class SymmetricPatchifier(Patchifier): - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - b, _, f, h, w = latents.shape - latent_coords = self.get_latent_coords(f, h, w, b, latents.device) - latents = rearrange( - latents, - "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", - p1=self._patch_size[0], - p2=self._patch_size[1], - p3=self._patch_size[2], - ) - return latents, latent_coords - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - output_height = output_height // self._patch_size[1] - output_width = output_width // self._patch_size[2] - latents = rearrange( - latents, - "b (f h w) (c p q) -> b c f (h p) (w q)", - h=output_height, - w=output_width, - p=self._patch_size[1], - q=self._patch_size[2], - ) - return latents + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 1c1807fdd..8b12b1d81 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -112,7 +112,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( #noqa: C408 + 
module_init_kwargs=dict( # noqa: C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, diff --git a/src/maxdiffusion/models/ltx_video/utils/__init__.py b/src/maxdiffusion/models/ltx_video/utils/__init__.py index cb4a6b9ce..285b6e81c 100644 --- a/src/maxdiffusion/models/ltx_video/utils/__init__.py +++ b/src/maxdiffusion/models/ltx_video/utils/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # # This implementation is based on the Torch version available at: -# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py index 53c0082d1..832bda051 100644 --- a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -1,13 +1,13 @@ def make_hashable_key(dict_key): - def convert_value(value): - if isinstance(value, list): - return tuple(value) - elif isinstance(value, dict): - return tuple(sorted((k, convert_value(v)) for k, v in value.items())) - else: - return value + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value - return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) DIFFUSERS_SCHEDULER_CONFIG = { diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py index 901051728..20d09f73e 100644 --- a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -45,20 +45,20 @@ def tensor_to_pil(tensor): - # Ensure tensor is in range [-1, 1] - assert tensor.min() >= -1 and tensor.max() <= 1 + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 - # Convert from [-1, 1] to [0, 1] - tensor = (tensor + 1) / 2 + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 - # Rearrange from [C, H, W] to [H, W, C] - tensor = tensor.permute(1, 2, 0) + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) - # Convert to numpy array and then to uint8 range [0, 255] - numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") - # Convert to PIL Image - return Image.fromarray(numpy_image) + # Convert to PIL Image + return Image.fromarray(numpy_image) def generate_cinematic_prompt( @@ -70,52 +70,45 @@ def generate_cinematic_prompt( conditioning_items: Optional[List] = None, max_new_tokens: int = 256, ) -> List[str]: - prompts = [prompt] if isinstance(prompt, str) else prompt - - if conditioning_items is None: - prompts = _generate_t2v_prompt( - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - max_new_tokens, - T2V_CINEMATIC_PROMPT, - ) - else: - if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: - logger.warning( - "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" - ) - return prompts - - first_frame_conditioning_item = conditioning_items[0] - 
first_frames = _get_first_frames_from_conditioning_item( - first_frame_conditioning_item - ) - - assert len(first_frames) == len( - prompts - ), "Number of conditioning frames must match number of prompts" - - prompts = _generate_i2v_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - first_frames, - max_new_tokens, - I2V_CINEMATIC_PROMPT, - ) - - return prompts + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item(first_frame_conditioning_item) + + assert len(first_frames) == len(prompts), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: - frames_tensor = conditioning_item.media_item - return [ - tensor_to_pil(frames_tensor[i, :, 0, :, :]) - for i in range(frames_tensor.shape[0]) - ] + frames_tensor = conditioning_item.media_item + return [tensor_to_pil(frames_tensor[i, :, 0, :, :]) for i in range(frames_tensor.shape[0])] def _generate_t2v_prompt( @@ -125,27 +118,18 @@ def _generate_t2v_prompt( max_new_tokens: int, system_prompt: str, ) -> List[str]: - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}"}, - ] - for p in prompts - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens - ) + texts = [prompt_enhancer_tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(prompt_enhancer_model.device) + + return _generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens) def _generate_i2v_prompt( @@ -158,31 +142,20 @@ def _generate_i2v_prompt( max_new_tokens: int, system_prompt: str, ) -> List[str]: - image_captions = _generate_image_captions( - image_caption_model, image_caption_processor, first_frames - ) + image_captions = _generate_image_captions(image_caption_model, image_caption_processor, first_frames) - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, - ] - for p, c in zip(prompts, image_captions) - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, 
add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens - ) + texts = [prompt_enhancer_tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(prompt_enhancer_model.device) + + return _generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens) def _generate_image_captions( @@ -191,36 +164,27 @@ def _generate_image_captions( images: List[Image.Image], system_prompt: str = "", ) -> List[str]: - image_caption_prompts = [system_prompt] * len(images) - inputs = image_caption_processor( - image_caption_prompts, images, return_tensors="pt" - ).to(image_caption_model.device) - - with torch.inference_mode(): - generated_ids = image_caption_model.generate( - input_ids=inputs["input_ids"], - pixel_values=inputs["pixel_values"], - max_new_tokens=1024, - do_sample=False, - num_beams=3, - ) + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor(image_caption_prompts, images, return_tensors="pt").to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) - return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) def _generate_and_decode_prompts( prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int ) -> List[str]: - with torch.inference_mode(): - outputs = prompt_enhancer_model.generate( - **model_inputs, max_new_tokens=max_new_tokens - ) - generated_ids = [ - output_ids[len(input_ids) :] - for input_ids, output_ids in zip(model_inputs.input_ids, outputs) - ] - decoded_prompts = prompt_enhancer_tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - ) - - return decoded_prompts + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate(**model_inputs, max_new_tokens=max_new_tokens) + generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, outputs)] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py index 30f9016e1..476d38c75 100644 --- a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -2,7 +2,7 @@ class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py index 991b07c36..32b19b167 100644 --- 
a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -3,23 +3,21 @@ def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - elif dims_to_append == 0: - return x - return x[(...,) + (None,) * dims_to_append] + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] class Identity(nn.Module): - """A placeholder identity operator that is argument-insensitive.""" + """A placeholder identity operator that is argument-insensitive.""" - def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument - super().__init__() + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() - # pylint: disable=unused-argument - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - return x + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index cfb8a72e5..a84dd491c 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -11,34 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse import math import os from jax import Array -from datetime import datetime -from pathlib import Path from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F -from diffusers.utils.torch_utils import randn_tensor from maxdiffusion.models.ltx_video.autoencoders.vae_torchax import TorchaxCausalVideoAutoencoder -# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -import yaml -from transformers import (CLIPTokenizer, FlaxCLIPTextModel, - T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) - -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from transformers import (FlaxT5EncoderModel, AutoTokenizer) from torchax import interop from torchax import default_env -import imageio import json import numpy as np import torch -from safetensors import safe_open -from PIL import Image from transformers import ( - T5EncoderModel, - T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, @@ -53,32 +39,24 @@ un_normalize_latents, normalize_latents, ) -# from diffusers.image_processor import VaeImageProcessor from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt from math import e from types import NoneType from typing import Any, Dict import numpy as np -import inspect import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P from typing import Optional, Union, List import torch -from maxdiffusion.checkpointing import checkpointing_utils from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...pyconfig import HyperParameters -# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState -from ...max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations -) +from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os import json @@ -87,1124 +65,751 @@ import pickle -class PickleCheckpointHandler(ocp.CheckpointHandler): - def save(self, directory: str, item, args=None): - os.makedirs(directory, exist_ok=True) - with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: - pickle.dump(item, f) - - def restore(self, directory: str, args=None): - with open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f: - return pickle.load(f) +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - def structure(self, directory: str): - return {} # not 
needed for simple pickle-based handling +class LTXVideoPipeline: -def save_tensor_dict(tensor_dict, timestep): + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: TorchaxCausalVideoAutoencoder, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - local_path = os.path.join(base_dir, f"schedulerTest{timestep}") - - try: - torch.save(tensor_dict, local_path) - print(f"Dictionary of tensors saved to: {local_path}") - except Exception as e: - print(f"Error saving dictionary: {e}") - raise - - -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) - print("encoder_attention_segment_ids.shape: ", - encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - - -def prepare_extra_step_kwargs(generator): - extra_step_kwargs = {} - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - -class LTXVideoPipeline: - def __init__( - self, - transformer: Transformer3DModel, - scheduler: FlaxRectifiedFlowMultistepScheduler, - scheduler_state: RectifiedFlowSchedulerState, - vae: TorchaxCausalVideoAutoencoder, - text_encoder, - patchifier, - tokenizer, + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + 
with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, jax.random.PRNGKey(0), model_config["caption_channels"], eval_only=True + ) + # loading from weight checkpoint + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16) + # in torchax + with default_env(): + torch_vae = torch_vae.to("jax") + jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) + return jax_vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return ( prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - transformer_state: Dict[Any, Any] = None, - transformer_state_shardings: Dict[Any, Any] = NoneType, - ): - self.transformer = transformer - self.devices_array = devices_array - self.mesh = mesh - self.config = config - self.p_run_inference = None - self.transformer_state = transformer_state - self.transformer_state_shardings = transformer_state_shardings - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.vae = vae - self.text_encoder = text_encoder - self.patchifier = patchifier - self.tokenizer = tokenizer - self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model - self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor - self.prompt_enhancer_llm_model = 
prompt_enhancer_llm_model - self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer - self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( - self.vae - ) - # self.image_processor = VaeImageProcessor( - # vae_scale_factor=self.vae_scale_factor) - - @classmethod - def load_scheduler(cls, ckpt_path, config): - # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - # "Wan-AI/Wan2.1-T2V-14B-Diffusers", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( - # "Lightricks/LTX-Video", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # import pdb; pdb.set_trace() - # scheduler = FlaxRectifiedFlowMultistepScheduler( - # sampler="LinearQuadratic" - # ) - # scheduler_state = scheduler.create_state() - if config.sampler == "from_checkpoint" or not config.sampler: - scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) - else: - scheduler = FlaxRectifiedFlowMultistepScheduler( - sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") - ) - scheduler_state = scheduler.create_state() - - return scheduler, scheduler_state - - @classmethod - def load_transformer(cls, config): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", - "causal_temporal_positioning", "in_channels", "ckpt_path"] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - transformer = Transformer3DModel( - # change this sharding back - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) #move the key outwards - transformer_param_shapes = transformer.init_weights( - in_channels, jax.random.PRNGKey(0), model_config['caption_channels'], eval_only=True) - weights_init_fn = functools.partial( - transformer.init_weights, - in_channels, - jax.random.PRNGKey(0), - model_config['caption_channels'], - eval_only=True - ) - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - transformer_state = jax.device_put( - transformer_state, transformer_state_shardings) - get_memory_allocations() - - return transformer, transformer_state, transformer_state_shardings - - @classmethod - def load_vae(cls, ckpt_path): - - torch_vae = CausalVideoAutoencoder.from_pretrained(ckpt_path, torch_dtype = torch.bfloat16) - with default_env(): - torch_vae = torch_vae.to('jax') - jax_vae = TorchaxCausalVideoAutoencoder(torch_vae) - return jax_vae - - @classmethod - def load_text_encoder(cls, ckpt_path): - # text_encoder = T5EncoderModel.from_pretrained( - # ckpt_path, subfolder="text_encoder" - # ) - t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) - cpus = jax.devices("cpu") - 
t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) - return t5_encoder - - @classmethod - def load_tokenizer(cls, config, ckpt_path): - # tokenizer = T5Tokenizer.from_pretrained( - # ckpt_path, subfolder="tokenizer" - # ) - t5_tokenizer = AutoTokenizer.from_pretrained( - ckpt_path, max_length=config.max_sequence_length, use_fast=True - ) - return t5_tokenizer - - @classmethod - def load_prompt_enhancement(cls, config): - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", - ) - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, - ) - return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer - - @classmethod - def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - transformer, transformer_state, transformer_state_shardings = cls.load_transformer( - config) - - # load from pytorch version - models_dir = "/mnt/disks/diffusionproj" #edit this! - ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" - if not os.path.isfile(ltxv_model_name_or_path): - ltxv_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=ltxv_model_name_or_path, - local_dir=models_dir, - repo_type="model", - ) - else: - ltxv_model_path = ltxv_model_name_or_path - - scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) - - vae = cls.load_vae(ltxv_model_path) - text_encoder = cls.load_text_encoder( - config.text_encoder_model_name_or_path) - patchifier = SymmetricPatchifier(patch_size=1) - tokenizer = cls.load_tokenizer( - config, config.text_encoder_model_name_or_path) - - enhance_prompt = False - if enhance_prompt: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( - config) - else: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None - - return LTXVideoPipeline( - transformer=transformer, - scheduler=scheduler, - scheduler_state=scheduler_state, - vae=vae, - text_encoder=text_encoder, - patchifier=patchifier, - tokenizer=tokenizer, - prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model=prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, - devices_array=devices_array, - mesh=mesh, - config=config, - transformer_state=transformer_state, - transformer_state_shardings=transformer_state_shardings - ) - - @classmethod - def _text_preprocessing(self, text): - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - text = text.strip() - return text - - return [process(t) for t in text] - def denoising_step( - scheduler, - latents: Array, - noise_pred: Array, - current_timestep: 
Optional[Array], - conditioning_mask: Optional[Array], - t: float, - extra_step_kwargs: Dict, - t_eps: float = 1e-6, - stochastic_sampling: bool = False, - ) -> Array: - """ - Perform the denoising step for the required tokens, based on the current timestep and - conditioning mask: - Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) - and will start to be denoised when the current timestep is equal or lower than their - conditioning timestep. - (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) - """ - # Denoise the latents using the scheduler - denoised_latents = scheduler.step( - noise_pred, - t if current_timestep is None else current_timestep, - latents, - **extra_step_kwargs, - stochastic_sampling=stochastic_sampling, - ) - - if conditioning_mask is None: - return denoised_latents - - tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) - tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) - return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) - - def retrieve_timesteps( #currently doesn't support custom timesteps - self, - scheduler: FlaxRectifiedFlowMultistepScheduler, - latent_shape, - scheduler_state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - timesteps: Optional[List[int]] = None, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, + ) + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) + + models_dir = config.models_dir + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + + vae = cls.load_vae(ltxv_model_path) + text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = cls.load_prompt_enhancement(config) + else: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = (None, None, None, None) + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings, + ) + + @classmethod + def _text_preprocessing(self, text): + 
if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( # currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps( + state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps + ) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps # Use the original num_inference_steps here for the check ): - scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) - timesteps = scheduler_state.timesteps - if ( - skip_initial_inference_steps < 0 - or skip_final_inference_steps < 0 - or skip_initial_inference_steps + skip_final_inference_steps - >= num_inference_steps # Use the original num_inference_steps here for the check - ): - raise ValueError( - "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" - ) - timesteps = timesteps[ - skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps - ] - scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + 
elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = text_encoder_max_tokens # TPU supports only lengths multiple of 128 + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) + prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, (1, num_images_per_prompt)) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( # currently no support for media item encoding, since encoder isn't tested + self, + latents: Optional[jnp.ndarray], + timestep: float, + latent_shape: Tuple[Any, ...], + dtype: jnp.dtype, + key: jax.random.PRNGKey, + ) -> jnp.ndarray: + """ + Prepares initial latents for a diffusion process, potentially encoding media items + or adding noise + """ + if latents is not None: + assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." 
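+    # The einops pattern "b (f h w) c -> b c f h w" is reproduced below without einops:
+    # reshape (b, f*h*w, c) to (b, f, h, w, c), then transpose to (b, c, f, h, w).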
+ + # For backward compatibility, generate noise in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + + # Generate noise using jax.random.normal + noise_intermediate_shape = (b, f * h * w, c) + noise = jax.random.normal(key, noise_intermediate_shape, dtype=dtype) + + # Rearrange "b (f h w) c -> b c f h w" + # Step 1: Reshape to (b, f, h, w, c) + noise = noise.reshape(b, f, h, w, c) + # Step 2: Permute/Transpose to (b, c, f, h, w) + noise = jnp.transpose(noise, (0, 4, 1, 2, 3)) # Old (b,f,h,w,c) -> New (b,c,f,h,w) + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + # Ensure timestep is a jnp.array with the correct dtype + timestep_array = jnp.array(timestep, dtype=dtype) + latents = timestep_array * noise + (1 - timestep_array) * latents + + return latents + + def prepare_conditioning( # no support for conditioning items, conditioning mask, needs to conver to torch before patchifier + self, + init_latents: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: + assert isinstance(self.vae, TorchaxCausalVideoAutoencoder) + init_latents = torch.from_numpy(np.array(init_latents)) + init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) + init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=True) + return ( + jnp.array(init_latents.to(torch.float32).detach().numpy()), + jnp.array(init_pixel_coords.to(torch.float32).detach().numpy()), + 0, + ) + + + def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Borrowed from diffusers.image_processor + Denormalize an image array to [0,1]. - return scheduler_state - - def encode_prompt( - self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - text_encoder_max_tokens: int = 256, - **kwargs, - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - max_length = ( - text_encoder_max_tokens # TPU supports only lengths multiple of 128 - ) - if prompt_embeds is None: - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," - # text_enc_device = next(self.text_encoder.parameters()) - prompt = self._text_preprocessing(prompt) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, max_length - 1: -1] - ) - - prompt_attention_mask = jnp.array(text_inputs.attention_mask) - prompt_embeds = self.text_encoder( - text_input_ids, attention_mask=prompt_attention_mask - ) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = 
self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype - else: - dtype = None - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - print(isinstance(prompt_embeds, Array)) - # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - # prompt_embeds = prompt_embeds.view( - # bs_embed * num_images_per_prompt, seq_len, -1 - # ) - prompt_embeds = jnp.reshape( - prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) - prompt_attention_mask = jnp.tile( - prompt_attention_mask, (1, num_images_per_prompt)) - # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) - # prompt_attention_mask = prompt_attention_mask.view( - # bs_embed * num_images_per_prompt, -1 - # ) - prompt_attention_mask = jnp.reshape( - prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - - # get unconditional embeddings for classifier free guidance hasn't changed yet - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = self._text_preprocessing(negative_prompt) - uncond_tokens = uncond_tokens * batch_size - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) - - negative_prompt_embeds = self.text_encoder( - jnp.array(uncond_input.input_ids), - attention_mask=negative_prompt_attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = jnp.tile(negative_prompt_embeds, - (1, num_images_per_prompt, 1) - ) - negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, - (batch_size * num_images_per_prompt, seq_len, -1) - ) - - negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, - (1, num_images_per_prompt) - ) - negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, - (bs_embed * num_images_per_prompt, -1) - ) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - data_sharding = jax.sharding.NamedSharding(self.mesh, P('data')) - return ( - jax.device_put(prompt_embeds, data_sharding),#(1, 256, 4096) - jax.device_put(prompt_attention_mask, data_sharding), #1, 256 - jax.device_put(negative_prompt_embeds, data_sharding), - jax.device_put(negative_prompt_attention_mask, data_sharding) - ) - - def prepare_latents( - self, - latents: torch.Tensor | None, - media_items: torch.Tensor | None, - timestep: float, - latent_shape: torch.Size | Tuple[Any, ...], - dtype: torch.dtype, - device: torch.device, - generator: torch.Generator | List[torch.Generator], - vae_per_channel_normalize: bool = True, - ): - if isinstance(generator, list) and len(generator) != latent_shape[0]: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." 
- ) - - # Initialize the latents with the given latents or encoded media item, if provided - assert ( - latents is None or media_items is None - ), "Cannot provide both latents and media_items. Please provide only one of the two." - - assert ( - latents is None and media_items is None or timestep < 1.0 - ), "Input media_item or latents are provided, but they will be replaced with noise." - - if media_items is not None: - latents = self.vae.encode( - media_itmes = media_items, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - if latents is not None: - assert ( - latents.shape == latent_shape - ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." - - # For backward compatibility, generate in the "patchified" shape and rearrange - b, c, f, h, w = latent_shape - noise = randn_tensor( ###need to remove this! - (b, f * h * w, c), generator=generator, device=device, dtype=dtype - ) - noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) - - # scale the initial noise by the standard deviation required by the scheduler - # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have - - if latents is None: - latents = noise - else: - # Noise the latents to the required (first) timestep - timestep = torch.from_numpy(np.array(timestep)) - latents = timestep * noise + (1 - timestep) * latents - - return latents - - def prepare_conditioning( - self, - conditioning_items, - init_latents: torch.Tensor, - num_frames: int, - height: int, - width: int, - vae_per_channel_normalize: bool = True, - generator=None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - - assert isinstance(self.vae, TorchaxCausalVideoAutoencoder) - - # if conditioning_items: - # batch_size, _, num_latent_frames = init_latents.shape[:3] - - # init_conditioning_mask = torch.zeros( - # init_latents[:, 0, :, :, :].shape, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents = [] - # extra_conditioning_pixel_coords = [] - # extra_conditioning_mask = [] - # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) - - # # Process each conditioning item - # for conditioning_item in conditioning_items: - # conditioning_item = self._resize_conditioning_item( - # conditioning_item, height, width - # ) - # media_item = conditioning_item.media_item - # media_frame_number = conditioning_item.media_frame_number - # strength = conditioning_item.conditioning_strength - # assert media_item.ndim == 5 # (b, c, f, h, w) - # b, c, n_frames, h, w = media_item.shape - # assert ( - # height == h and width == w - # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - # assert n_frames % 8 == 1 - # assert ( - # media_frame_number >= 0 - # and media_frame_number + n_frames <= num_frames - # ) - - # # Encode the provided conditioning media item - # media_item_latents = vae_encode( - # media_item.to(dtype=self.vae.dtype, device=self.vae.device), - # self.vae, - # vae_per_channel_normalize=vae_per_channel_normalize, - # ).to(dtype=init_latents.dtype) - - # # Handle the different conditioning cases - # if media_frame_number == 0: - # # Get the target spatial position of the latent conditioning item - # media_item_latents, l_x, l_y = self._get_latent_spatial_position( - # media_item_latents, - # conditioning_item, - # height, - # width, - # strip_latent_border=True, - # ) - # b, c_l, f_l, h_l, w_l = media_item_latents.shape 
- - # # First frame or sequence - just update the initial noise latents and the mask - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( - # torch.lerp( - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], - # media_item_latents, - # strength, - # ) - # ) - # init_conditioning_mask[ - # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l - # ] = strength - # else: - # # Non-first frame or sequence - # if n_frames > 1: - # # Handle non-first sequence. - # # Encoded latents are either fully consumed, or the prefix is handled separately below. - # ( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # ) = self._handle_non_first_conditioning_sequence( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # media_frame_number, - # strength, - # ) - - # # Single frame or sequence-prefix latents - # if media_item_latents is not None: - # noise = randn_tensor( - # media_item_latents.shape, - # generator=generator, - # device=media_item_latents.device, - # dtype=media_item_latents.dtype, - # ) - - # media_item_latents = torch.lerp( - # noise, media_item_latents, strength - # ) - - # # Patchify the extra conditioning latents and calculate their pixel coordinates - # media_item_latents, latent_coords = self.patchifier.patchify( - # latents=media_item_latents - # ) - # pixel_coords = latent_to_pixel_coords( - # latent_coords, - # self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, - # ) - - # # Update the frame numbers to match the target frame number - # pixel_coords[:, 0] += media_frame_number - # extra_conditioning_num_latents += media_item_latents.shape[1] - - # conditioning_mask = torch.full( - # media_item_latents.shape[:2], - # strength, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents.append(media_item_latents) - # extra_conditioning_pixel_coords.append(pixel_coords) - # extra_conditioning_mask.append(conditioning_mask) - - # Patchify the updated latents and calculate their pixel coordinates - init_latents, init_latent_coords = self.patchifier.patchify( - latents=init_latents - ) - init_pixel_coords = latent_to_pixel_coords( - init_latent_coords, - self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now - causal_fix=True - - ) - - if not conditioning_items: - return init_latents, init_pixel_coords, None, 0 - - # init_conditioning_mask, _ = self.patchifier.patchify( - # latents=init_conditioning_mask.unsqueeze(1) - # ) - # init_conditioning_mask = init_conditioning_mask.squeeze(-1) - - # if extra_conditioning_latents: - # # Stack the extra conditioning latents, pixel coordinates and mask - # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) - # init_pixel_coords = torch.cat( - # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 - # ) - # init_conditioning_mask = torch.cat( - # [*extra_conditioning_mask, init_conditioning_mask], dim=1 - # ) - - # if self.transformer.use_tpu_flash_attention: - # # When flash attention is used, keep the original number of tokens by removing - # # tokens from the end. 
- # init_latents = init_latents[:, :-extra_conditioning_num_latents] - # init_pixel_coords = init_pixel_coords[ - # :, :, :-extra_conditioning_num_latents - # ] - # init_conditioning_mask = init_conditioning_mask[ - # :, :-extra_conditioning_num_latents - # ] - - # return ( - # init_latents, - # init_pixel_coords, - # init_conditioning_mask, - # extra_conditioning_num_latents, - # ) - - # change the paramters of these, currently pass in dummy inputs + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. + """ + return (images * 0.5 + 0.5).clamp(0, 1) + + def _denormalize_conditionally(self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None) -> torch.Tensor: + r""" + Borrowed from diffusers.image_processor + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) + + return torch.stack([self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]) + + # same as diffusers image_processor.postprocess + + def postprocess_to_output_type(self, image, output_type): + ''' + Currrently supporting output type latent and pt + ''' + if not isinstance(image, torch.Tensor): + raise ValueError(f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor") + + if output_type not in ["latent", "pt", "np", "pil"]: + output_type = "np" + + if output_type == "latent": + return image + image = self._denormalize_conditionally(image, None) + if output_type == "pt": + return image + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + frame_rate: int = 30, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + 
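+    # Latent grid sizes are the pixel/frame dimensions divided by the VAE strides
+    # (e.g. a stride-32 spatial VAE maps 512 px to 16 latent cells); causal video
+    # VAEs get one extra latent frame below.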
latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, TorchaxCausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps( + self.scheduler, + latent_shape, + self.scheduler_state, + num_inference_steps, + None, + skip_initial_inference_steps, + skip_final_inference_steps, + ) + + #set up guidance + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep] + guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + + #prepare skip block list + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) + for skip_blocks in skip_block_list + ] - #copied from image_processor ?? change?? move to a seperate file???? - - def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - r""" - Denormalize an image array to [0,1]. - - Args: - images (`np.ndarray` or `torch.Tensor`): - The image array to denormalize. - - Returns: - `np.ndarray` or `torch.Tensor`: - The denormalized image array. 
- """ - return (images * 0.5 + 0.5).clamp(0, 1) + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask - def _denormalize_conditionally( - self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None - ) -> torch.Tensor: - r""" - Denormalize a batch of images based on a condition list. - - Args: - images (`torch.Tensor`): - The input image tensor. - do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): - A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the - value of `do_normalize` in the `VaeImageProcessor` config. - """ - if do_denormalize is None: - return self.denormalize(images) - - return torch.stack( - [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])] - ) - #same as diffusers image_processor.postprocess + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) - def postprocess_to_output_type(self, image, output_type): #support latent & pt - if not isinstance(image, torch.Tensor): - raise ValueError( - f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support pytorch tensor" - ) - - if output_type not in ["latent", "pt", "np", "pil"]: - output_type = "np" - - if output_type == "latent": - return image - image = self._denormalize_conditionally(image, None) #do denormalize set to none - if output_type == "pt": - return image - + #optionally pass in a latent here + latents = self.prepare_latents( + latents=latents, + timestep=scheduler_state.timesteps[0], + latent_shape=latent_shape, + dtype=jnp.float32, + key=jax.random.PRNGKey(0), + ) + + latents, pixel_coords, num_cond_latents = self.prepare_conditioning( + init_latents=latents, + ) + + pixel_coords = jnp.concatenate([pixel_coords] * num_conds, axis=0) + fractional_coords = pixel_coords.astype(jnp.float32) + fractional_coords = fractional_coords.at[:, 0].set(fractional_coords[:, 0] * (1.0 / frame_rate)) + + # initialize dummy noise + noise_cond = jnp.ones((1, 1)) + + p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling=do_rescaling, + rescaling_scale=rescaling_scale, + batch_size=batch_size, + skip_layer_masks=skip_layer_masks, + skip_layer_strategy=skip_layer_strategy, + cfg_star_rescale=cfg_star_rescale, + ) + + with self.mesh: + latents, scheduler_state = p_run_inference( + transformer_state=self.transformer_state, latents=latents, timestep=noise_cond, scheduler_state=scheduler_state + ) - def __call__( - self, - height: int, - width: int, - num_frames: int, - negative_prompt: str = "", - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - frame_rate: int = 30, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - guidance_timesteps: Optional[List[int]] = None, - decode_timestep: Union[List[float], float] = 0.05, - decode_noise_scale: Optional[List[float]] = 0.025, - offload_to_cpu: bool = False, - enhance_prompt: bool = False, - text_encoder_max_tokens: int = 256, - num_inference_steps: int = 50, - guidance_scale: Union[float, List[float]] = 4.5, - rescaling_scale: Union[float, List[float]] = 0.7, - stg_scale: Union[float, List[float]] = 1.0, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - cfg_star_rescale: bool = False, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, - **kwargs, - ): - enhance_prompt = False - prompt = self.config.prompt - is_video = kwargs.get("is_video", False) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - vae_per_channel_normalize = kwargs.get( - "vae_per_channel_normalize", True) - - - 
latent_height = height // self.vae_scale_factor - latent_width = width // self.vae_scale_factor - latent_num_frames = num_frames // self.video_scale_factor - if isinstance(self.vae, TorchaxCausalVideoAutoencoder) and is_video: - latent_num_frames += 1 - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - - latent_shape = ( - batch_size * num_images_per_prompt, - model_config["in_channels"], - latent_num_frames, - latent_height, - latent_width, - ) - scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) - - - guidance_mapping = [] - - if guidance_timesteps: - for timestep in scheduler_state.timesteps: - indices = [ - i for i, val in enumerate(guidance_timesteps) if val <= timestep - ] - guidance_mapping.append( - indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) - ) - - if not isinstance(guidance_scale, list): - guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) - else: - guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(stg_scale, list): - stg_scale = [stg_scale] * len(scheduler_state.timesteps) - else: - stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(rescaling_scale, list): - rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) - else: - rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) - do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) - do_rescaling = any(x != 1.0 for x in rescaling_scale) - - - num_conds = 1 - if do_classifier_free_guidance: - num_conds += 1 - if do_spatio_temporal_guidance: - num_conds += 1 - - is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) - - if not is_list_of_lists: - skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) - else: - new_skip_block_list = [] - for i in range(len(scheduler_state.timesteps)): - new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) - - skip_block_list = new_skip_block_list - - if do_spatio_temporal_guidance: - if skip_block_list is not None: - skip_layer_masks = [ - self.transformer.create_skip_layer_mask( - batch_size, num_conds, num_conds - 1, skip_blocks - ) - for skip_blocks in skip_block_list - ] - if enhance_prompt: - prompt = generate_cinematic_prompt( - self.prompt_enhancer_image_caption_model, - self.prompt_enhancer_image_caption_processor, - self.prompt_enhancer_llm_model, - self.prompt_enhancer_llm_tokenizer, - prompt, - None, # conditioning items set to None - max_new_tokens=text_encoder_max_tokens, - ) - - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - device=None, # device set to none - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - 
text_encoder_max_tokens=text_encoder_max_tokens, - ) - prompt_embeds_batch = prompt_embeds - prompt_attention_mask_batch = prompt_attention_mask - if do_classifier_free_guidance: - prompt_embeds_batch = jnp.concatenate( - [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension - ) - prompt_attention_mask_batch = jnp.concatenate( - [negative_prompt_attention_mask, prompt_attention_mask], axis=0 - ) - if do_spatio_temporal_guidance: - prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [ - prompt_attention_mask_batch, - prompt_attention_mask, - ], - axis=0, - ) - latents = self.prepare_latents( - latents=latents, - media_items=None, # set to None - timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this - latent_shape=latent_shape, - dtype=None, - device=None, - generator=generator, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - - latents, pixel_coords, conditioning_mask, num_cond_latents = ( - self.prepare_conditioning( - conditioning_items=None, - init_latents=latents, - num_frames=num_frames, - height=height, - width=width, - vae_per_channel_normalize=vae_per_channel_normalize, - generator=generator, - ) - ) - - extra_step_kwargs = prepare_extra_step_kwargs( - generator=jax.random.PRNGKey(0)) - - pixel_coords = torch.cat([pixel_coords] * num_conds) - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - - # if not isinstance(guidance_scale, List): - # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) - # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - # latents = jax.device_put(example_inputs["latents"], data_sharding) - # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) - noise_cond = jnp.ones( # initialize first round with this! - (4, 1) - ) - - # # noise_cond = None - # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" - # tensor_dict = torch.load(saved_tensor_path) - - # for key, value in tensor_dict.items(): - # if value is not None: - # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) - # example_inputs = tensor_dict - # latents = jax.device_put(example_inputs["latent_model_input"]) - # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) - # fractional_coords = jax.device_put(example_inputs["indices_grid"]) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) - # segment_ids = None - # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) - - # #only run this for the first time! 
- # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) - # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken - # scheduler_state = self.scheduler_state - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = jax.jit( - # functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # ), - # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct - # out_shardings=None, - # ) - segment_ids = None - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # # guidance_scale=guidance_scale - # ) - p_run_inference = functools.partial( - run_inference, - transformer=self.transformer, - config=self.config, - mesh=self.mesh, - fractional_cords=jnp.array( - fractional_coords.to(torch.float32).detach().numpy()), - prompt_embeds=prompt_embeds_batch, - segment_ids=None, - encoder_attention_segment_ids=prompt_attention_mask_batch, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - do_classifier_free_guidance=do_classifier_free_guidance, - num_conds=num_conds, - guidance_scale=guidance_scale, - do_spatio_temporal_guidance=do_spatio_temporal_guidance, - stg_scale=stg_scale, - do_rescaling = do_rescaling, - rescaling_scale = rescaling_scale, - batch_size=batch_size, - skip_layer_masks = skip_layer_masks, - skip_layer_strategy = skip_layer_strategy, - cfg_star_rescale = cfg_star_rescale - ) - - with self.mesh: - latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( - # add scheduler state back in - torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) - # latents = torch.from_numpy(np.array(latents)) - latents = latents[:, num_cond_latents:] - - latents = self.patchifier.unpatchify( - latents=latents, - output_height=latent_height, - output_width=latent_width, - out_channels=model_config["in_channels"] - // math.prod(self.patchifier.patch_size), - ) - - if output_type != "latent": - if self.vae.decoder.timestep_conditioning: - noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) #move the key to outer layer - - # Convert decode_timestep to a list if it's not already one - if not isinstance(decode_timestep, (list, jnp.ndarray)): - decode_timestep = [decode_timestep] * latents.shape[0] - - # Handle decode_noise_scale - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): - decode_noise_scale = 
[decode_noise_scale] * latents.shape[0] - - # Convert lists to JAX arrays - decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) - - # Reshape decode_noise_scale for broadcasting - decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) - decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) - - # Apply the noise and scale - latents = ( - latents * (1 - decode_noise_scale) + - noise * decode_noise_scale - ) - else: - decode_timestep = None - image = self.vae.decode( - latents = jax.device_put(latents, jax.devices('tpu')[0]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]), - is_video = is_video, - vae_per_channel_normalize=kwargs.get( - "vae_per_channel_normalize", True), - timestep=decode_timestep #.astype(jnp.bfloat16), - ) - image = self.postprocess_to_output_type( #swap this out! - torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type) - - else: - image = latents - - # Offload all models - - if not return_dict: - return (image,) - - return image - # save states here - - -def transformer_forward_pass( # need to jit this? wan didnt + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), + ) + + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) + + # Convert decode_timestep to a list if it's not already one + if not isinstance(decode_timestep, (list, jnp.ndarray)): + decode_timestep = [decode_timestep] * latents.shape[0] + + # Handle decode_noise_scale + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + # Convert lists to JAX arrays + decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) + + # Reshape decode_noise_scale for broadcasting + decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32) + decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1)) + + # Apply the noise and scale + latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale + else: + decode_timestep = None + image = self.vae.decode( + latents=jax.device_put(latents, jax.devices("tpu")[0]), + is_video=is_video, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + # convert back to torch to post process using the diffusers library + image = self.postprocess_to_output_type( + torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type + ) + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + + + +def transformer_forward_pass( latents, state, noise_cond, @@ -1217,244 +822,241 @@ def transformer_forward_pass( # need to jit this? wan didnt skip_layer_strategy, ): - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy - ) # need .param here? 
- return noise_pred, state + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + ) + return noise_pred, state def run_inference( - transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale + transformer_state, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + scheduler, + segment_ids, + encoder_attention_segment_ids, + scheduler_state, + do_classifier_free_guidance, + num_conds, + guidance_scale, + do_spatio_temporal_guidance, + stg_scale, + do_rescaling, + rescaling_scale, + batch_size, + skip_layer_masks, + skip_layer_strategy, + cfg_star_rescale, ): - # do_classifier_free_guidance = guidance_scale > 1.0 - # for step in range(num_inference_steps): - for i, t in enumerate(scheduler_state.timesteps): - current_timestep = t - latent_model_input = ( - jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents - ) - # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) - # timestep = jnp.broadcast_to( - # t, (latent_model_input.shape[0],) - # ).reshape(-1, 1) - if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - # Determine the correct dtype based on the device (similar to PyTorch) - is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. 
- if isinstance(current_timestep, float): - dtype = jnp.float32 - else: - dtype = jnp.int32 - - current_timestep = jnp.array( - [current_timestep], - dtype=dtype, - ) - elif current_timestep.ndim == 0: - current_timestep = jnp.expand_dims(current_timestep, axis=0) + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) # Broadcast to batch dimension (compatible with ONNX/Core ML) - current_timestep = jnp.broadcast_to( - current_timestep, (latent_model_input.shape[0],1) - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( - skip_layer_masks[i] - if skip_layer_masks is not None - else None - ), skip_layer_strategy = skip_layer_strategy) - # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) - - # # latents = self.denoising - # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated - if do_spatio_temporal_guidance: - # JAX uses jnp.split for splitting along an axis. 
- # Equivalent to .chunk() for equal-sized chunks - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_text = chunks[-2] - noise_pred_text_perturb = chunks[-1] - - if do_classifier_free_guidance: - - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_uncond = chunks[0] - noise_pred_text = chunks[1] - if cfg_star_rescale: - positive_flat = noise_pred_text.reshape(batch_size, -1) - negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum( #(1, 1) - positive_flat * negative_flat, axis=1, keepdims=True - ) - squared_norm = ( #(1, 1) - jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 - ) - alpha = dot_product / squared_norm #might need to reshape this (1, 1) - alpha = alpha.reshape(batch_size, 1, 1) - - noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) - noise_pred = noise_pred_uncond + guidance_scale[i] * ( - noise_pred_text - noise_pred_uncond - ) - elif do_spatio_temporal_guidance: - noise_pred = noise_pred_text - - if do_spatio_temporal_guidance: - noise_pred = noise_pred + stg_scale[i] * ( - noise_pred_text - noise_pred_text_perturb - ) - if do_rescaling and stg_scale[i] > 0.0: - noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) - noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) - - factor = noise_pred_text_std / noise_pred_std - factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) - - - noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] # JAX slicing is similar - latents, scheduler_state = scheduler.step( - scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() - - return latents, scheduler_state - -def adain_filter_latent( - latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 -): - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor. - - Args: - latent (torch.Tensor): Input latents to normalize - reference_latent (torch.Tensor): The reference latents providing style statistics. - factor (float): Blending factor between original and transformed latent. 
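+    # One transformer forward pass predicts noise for every conditioning copy in the
+    # batch; the chunks are recombined below: classifier-free guidance computes
+    # uncond + guidance_scale[i] * (text - uncond), spatio-temporal guidance adds
+    # stg_scale[i] * (text - text_perturb), optionally rescaled to the text-branch std.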
- Range: -10.0 to 10.0, Default: 1.0 + current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, + transformer_state, + current_timestep, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), + skip_layer_strategy=skip_layer_strategy, + ) + + #perform guidance on noise prediction + if do_spatio_temporal_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm + alpha = alpha.reshape(batch_size, 1, 1) + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + + + + +#up to here +def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. 
+ Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + with default_env(): + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result - Returns: - torch.Tensor: The transformed latent tensor - """ - with default_env(): - result = latents.clone() - for i in range(latents.size(0)): - for c in range(latents.size(1)): - r_sd, r_mean = torch.std_mean( - reference_latents[i, c], dim=None - ) # index by original dim order - i_sd, i_mean = torch.std_mean(result[i, c], dim=None) +class LTXMultiScalePipeline: ##figure these methods out - result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): + latents = jax.device_put(latents, jax.devices("tpu")[0]) + # assert latents.device == latest_upsampler.device + with default_env(): + latents = un_normalize_latents( # need to switch this out? + interop.torch_view(latents), self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler( + torch.from_numpy(np.array(latents)) + ) # here converted back to torch, cause upsampler in pytorch + upsampled_latents = normalize_latents( + interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents + + def __init__(self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + self.latent_upsampler = latent_upsampler + + def __call__(self, height, width, num_frames, is_video, output_type, generator, config) -> Any: + + original_output_type = output_type + original_width = width + original_height = height + x_width = int(width * config.downscale_factor) + downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) + x_height = int(height * config.downscale_factor) + downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) + # #use original height and width here to see + output_type = "latent" + result = self.video_pipeline( + height=original_height, + width=original_width, + num_frames=num_frames, + is_video=True, + output_type=output_type, + generator=generator, + guidance_scale=config.first_pass["guidance_scale"], + stg_scale=config.first_pass["stg_scale"], + rescaling_scale=config.first_pass["rescaling_scale"], + skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.first_pass["skip_final_inference_steps"], + num_inference_steps=config.first_pass["num_inference_steps"], + guidance_timesteps=config.first_pass["guidance_timesteps"], + cfg_star_rescale=config.first_pass["cfg_star_rescale"], + skip_layer_strategy=None, + skip_block_list=config.first_pass["skip_block_list"], + ) + latents = result + print("done") + upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) # convert back to pytorch here + ##maybe change this? 
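+    # Second stage of the multi-scale render: the upsampled first-pass latents are
+    # style-matched to the originals with AdaIN (adain_filter_latent) and then reused
+    # as the initial latents for a higher-resolution pass through the same pipeline.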
+ latents = torch.from_numpy(np.array(latents)) # .to(torch.device('cpu')) + upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) # .to(torch.device('cpu')) + upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) + latents = upsampled_latents + output_type = original_output_type + width = downscaled_width * 2 + height = downscaled_height * 2 + # import pdb; pdb.set_trace() + latents = jnp.array(latents) + result = self.video_pipeline( + height=original_height * 2, + width=original_width * 2, + num_frames=num_frames, + is_video=True, + output_type=original_output_type, + latents=latents, + generator=generator, + guidance_scale=config.second_pass["guidance_scale"], + stg_scale=config.second_pass["stg_scale"], + rescaling_scale=config.second_pass["rescaling_scale"], + skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.second_pass["skip_final_inference_steps"], + num_inference_steps=config.second_pass["num_inference_steps"], + guidance_timesteps=config.second_pass["guidance_timesteps"], + cfg_star_rescale=config.second_pass["cfg_star_rescale"], + skip_layer_strategy=None, + skip_block_list=config.second_pass["skip_block_list"], + ) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(original_height, original_width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos - result = torch.lerp(latents, result, factor) return result - -class LTXMultiScalePipeline: ##figure these methods out - - @classmethod - def load_latent_upsampler(cls, config): - spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir=config.models_dir, - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) - latent_upsampler.eval() - return latent_upsampler - - def _upsample_latents( - self, latest_upsampler: LatentUpsampler, latents: torch.Tensor - ): - latents = jax.device_put(latents, jax.devices('tpu')[0]) - #assert latents.device == latest_upsampler.device - with default_env(): - latents = un_normalize_latents( #need to switch this out? 
- interop.torch_view(latents), self.vae, vae_per_channel_normalize=True - ) - upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) #here converted back to torch, cause upsampler in pytorch - upsampled_latents = normalize_latents( - interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True - ) - return upsampled_latents - - def __init__( - self, video_pipeline: LTXVideoPipeline - ): - self.video_pipeline = video_pipeline - self.vae = video_pipeline.vae - - def __call__( - self, - height, - width, - num_frames, - output_type, - generator, - config - ) -> Any: - original_output_type = output_type - original_width = width - original_height = height - x_width = int(width * config.downscale_factor) - downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) - x_height = int(height * config.downscale_factor) - downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) - # #use original height and width here to see - output_type = 'latent' - result = self.video_pipeline(height=original_height, width=original_width, num_frames=num_frames, - is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], - num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) - latents = result - print("done") - latent_upsampler = self.load_latent_upsampler(config) - upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here - ##maybe change this? 
- latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu')) - upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu')) - upsampled_latents = adain_filter_latent( - latents=upsampled_latents, reference_latents=latents - ) - latents = upsampled_latents - output_type = original_output_type - width = downscaled_width * 2 - height = downscaled_height * 2 - # import pdb; pdb.set_trace() - result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, - is_video=True, output_type=original_output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], - num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.second_pass["skip_block_list"]) - - if original_output_type != "latent": - num_frames = result.shape[2] - videos = rearrange(result, "b c f h w -> (b f) c h w") - - videos = F.interpolate( - videos, - size=(original_height, original_width), - mode="bilinear", - align_corners=False, - ) - videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) - result = videos - - return result \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py index b550aeea3..6c39730df 100644 --- a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -295,7 +295,7 @@ def step( timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) if timestep.ndim == 0: - idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side="right") #noqa: F841 + idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side="right") # noqa: F841 current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), timesteps_padded[current_t_idx + 1], 0.0) dt = timestep - lower_timestep From 0577d3ebe1b247c8b6e42d641de18808b47db4ac Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 22 Jul 2025 02:04:25 +0000 Subject: [PATCH 56/69] pipeline cleaned --- src/maxdiffusion/configs/ltx_video.yml | 10 +- src/maxdiffusion/generate_ltx_video.py | 66 +---- .../pipelines/ltx_video/ltx_video_pipeline.py | 230 ++++++++++-------- 3 files changed, 149 insertions(+), 157 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 1259d2dab..5798d1c8a 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -19,16 +19,12 @@ frame_rate: 30 max_sequence_length: 512 sampler: "from_checkpoint" - - - - # Generation parameters -pipeline_type: None -prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. 
His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." +pipeline_type: multi-scale +prompt: "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage." height: 512 width: 512 -num_frames: 88 #344 +num_frames: 344 #344 flow_shift: 5.0 fps: 24 downscale_factor: 0.6666666 diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 5383ab3cd..11d18bf77 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,16 +4,12 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -import jax.numpy as jnp from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from huggingface_hub import hf_hub_download import imageio from datetime import datetime -from maxdiffusion.utils import export_to_video import os -import json import torch from pathlib import Path @@ -62,25 +58,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: return "-".join(result) -def create_latent_upsampler(latent_upsampler_model_path: str, device: str): - latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) - latent_upsampler.to(device) - latent_upsampler.eval() - return latent_upsampler def get_unique_filename( base: str, ext: str, prompt: str, - seed: int, resolution: tuple[int, int, int], dir: Path, endswith=None, index_range=1000, ) -> Path: base_filename = ( - f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}" ) for i in range(index_range): filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" @@ -94,55 +84,23 @@ def run(config): width_padded = ((config.width - 1) // 32 + 1) * 32 num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 padding = calculate_padding(config.height, config.width, height_padded, width_padded) - # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold - # prompt_word_count = len(config.prompt.split()) - # enhance_prompt = ( - # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold - # ) - - seed = 10 # change this, generator in pytorch, used in prepare_latents - generator = torch.Generator().manual_seed(seed) - pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False) - if config.pipeline_type == "multi-scale": # move this to pipeline file?? 
- spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir="/mnt/disks/diffusionproj", - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = create_latent_upsampler(spatial_upscaler_model_path, "cpu") # device set to cpu for now - pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) - stg_mode = config.stg_mode - if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": - skip_layer_strategy = SkipLayerStrategy.AttentionValues - elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": - skip_layer_strategy = SkipLayerStrategy.AttentionSkip - elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": - skip_layer_strategy = SkipLayerStrategy.Residual - elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": - skip_layer_strategy = SkipLayerStrategy.TransformerBlock - else: - raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") - # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, - # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, - # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images + prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold + prompt_word_count = len(config.prompt.split()) + enhance_prompt = ( + prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold + ) + + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) + if config.pipeline_type == "multi-scale": + pipeline = LTXMultiScalePipeline(pipeline) images = pipeline( height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type="pt", - generator=generator, config=config, + enhance_prompt = False ) (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom @@ -167,7 +125,6 @@ def run(config): f"image_output_{i}", ".png", prompt=config.prompt, - seed=seed, resolution=(height, width, config.num_frames), dir=output_dir, ) @@ -177,7 +134,6 @@ def run(config): f"video_output_{i}", ".mp4", prompt=config.prompt, - seed=seed, resolution=(height, width, config.num_frames), dir=output_dir, ) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index a84dd491c..3e45c8c73 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -13,6 +13,7 @@ # limitations under the License. 
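A small worked example of the padding arithmetic in run() above, where spatial dimensions are rounded up to a multiple of 32, the frame count to 8*k + 1, and the symmetric crop offsets are recovered from the difference; the input sizes below are illustrative only.

def round_up(x: int, m: int) -> int:
  return ((x - 1) // m + 1) * m

height, width, num_frames = 500, 512, 88
height_padded = round_up(height, 32)                     # 512
width_padded = round_up(width, 32)                       # 512
num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1  # 89

pad_height = height_padded - height                      # 12
pad_top = pad_height // 2                                # 6
pad_bottom = pad_height - pad_top                        # 6 (handles odd totals)
print(height_padded, width_padded, num_frames_padded, pad_top, pad_bottom)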
import math import os +from xmlrpc.client import Boolean from jax import Array from typing import Optional, List, Union, Tuple from einops import rearrange @@ -242,7 +243,6 @@ def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): patchifier = SymmetricPatchifier(patch_size=1) tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) - enhance_prompt = False if enhance_prompt: ( prompt_enhancer_image_caption_model, @@ -343,16 +343,12 @@ def retrieve_timesteps( # currently doesn't support custom timesteps return scheduler_state - def encode_prompt( + def encode_prompt( #input changed self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: str = "", num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, text_encoder_max_tokens: int = 256, **kwargs, ): @@ -364,28 +360,28 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] max_length = text_encoder_max_tokens # TPU supports only lengths multiple of 128 - if prompt_embeds is None: - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," - prompt = self._text_preprocessing(prompt) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) - prompt_attention_mask = jnp.array(text_inputs.attention_mask) - prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0] + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) @@ -394,7 +390,7 @@ def encode_prompt( prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: + if do_classifier_free_guidance: uncond_tokens = self._text_preprocessing(negative_prompt) uncond_tokens = uncond_tokens * batch_size max_length = prompt_embeds.shape[1] 
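A minimal, self-contained sketch of the fixed-length tokenization pattern used by encode_prompt above: prompts are padded or truncated to a static max_length (a multiple of 128, TPU-friendly), and the conditional and unconditional rows are kept together for classifier-free guidance. The T5 checkpoint name below is illustrative, not the one the pipeline actually loads.

import jax.numpy as jnp
from transformers import AutoTokenizer

MAX_TOKENS = 256  # multiple of 128 so the compiled TPU graph sees one static length

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint

def encode(prompt: str, negative_prompt: str = ""):
  # Tokenize both prompts to the same fixed length; truncation keeps shapes static.
  batch = tokenizer(
      [prompt, negative_prompt],
      padding="max_length",
      max_length=MAX_TOKENS,
      truncation=True,
      return_tensors="np",
  )
  input_ids = jnp.asarray(batch["input_ids"])            # (2, MAX_TOKENS)
  attention_mask = jnp.asarray(batch["attention_mask"])  # (2, MAX_TOKENS)
  # A real pipeline would run input_ids through the text encoder and keep the
  # conditional/unconditional halves stacked for classifier-free guidance.
  return input_ids, attention_mask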
@@ -546,11 +542,7 @@ def __call__( negative_prompt: str = "", num_images_per_prompt: Optional[int] = 1, frame_rate: int = 30, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + latents: Optional[jnp.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, guidance_timesteps: Optional[List[int]] = None, @@ -569,7 +561,7 @@ def __call__( skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, **kwargs, ): - enhance_prompt = False + import pdb; pdb.set_trace() prompt = self.config.prompt is_video = kwargs.get("is_video", False) if prompt is not None and isinstance(prompt, str): @@ -681,10 +673,6 @@ def __call__( do_classifier_free_guidance, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, text_encoder_max_tokens=text_encoder_max_tokens, ) prompt_embeds_batch = prompt_embeds @@ -732,7 +720,6 @@ def __call__( prompt_embeds=prompt_embeds_batch, segment_ids=None, encoder_attention_segment_ids=prompt_attention_mask_batch, - num_inference_steps=num_inference_steps, scheduler=self.scheduler, do_classifier_free_guidance=do_classifier_free_guidance, num_conds=num_conds, @@ -749,7 +736,7 @@ def __call__( with self.mesh: latents, scheduler_state = p_run_inference( - transformer_state=self.transformer_state, latents=latents, timestep=noise_cond, scheduler_state=scheduler_state + transformer_state=self.transformer_state, latents=latents, scheduler_state=scheduler_state ) latents = latents[:, num_cond_latents:] @@ -934,41 +921,104 @@ def run_inference( #up to here -def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor. - - Args: - latent (torch.Tensor): Input latents to normalize - reference_latent (torch.Tensor): The reference latents providing style statistics. - factor (float): Blending factor between original and transformed latent. - Range: -10.0 to 10.0, Default: 1.0 - - Returns: - torch.Tensor: The transformed latent tensor - """ - with default_env(): - result = latents.clone() - - for i in range(latents.size(0)): - for c in range(latents.size(1)): - r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order - i_sd, i_mean = torch.std_mean(result[i, c], dim=None) +# def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): +# """ +# Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on +# statistics from a reference latent tensor. + +# Args: +# latent (torch.Tensor): Input latents to normalize +# reference_latent (torch.Tensor): The reference latents providing style statistics. +# factor (float): Blending factor between original and transformed latent. 
+# Range: -10.0 to 10.0, Default: 1.0 + +# Returns: +# torch.Tensor: The transformed latent tensor +# """ +# with default_env(): +# result = latents.clone() + +# for i in range(latents.size(0)): +# for c in range(latents.size(1)): +# r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order +# i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + +# result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + +# result = torch.lerp(latents, result, factor) +# return result + +def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, factor: float = 1.0) -> jnp.ndarray: + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor, implemented in JAX. - result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + Args: + latents (jax.Array): Input latents to normalize. Expected shape (B, C, F, H, W). + reference_latents (jax.Array): The reference latents providing style statistics. + Expected shape (B, C, F, H, W). + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 - result = torch.lerp(latents, result, factor) - return result + Returns: + jax.Array: The transformed latent tensor. + """ + with default_env(): + latents = jax.device_put(latents, jax.devices("tpu")[0]) + reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) + # Define the core AdaIN operation for a single (F, H, W) slice. + # This function will be vmapped over batch (B) and channel (C) dimensions. + def _adain_single_slice(latent_slice: jnp.ndarray, ref_latent_slice: jnp.ndarray) -> jnp.ndarray: + """ + Applies AdaIN to a single latent slice (F, H, W) based on a reference slice. 
+ """ + r_mean = jnp.mean(ref_latent_slice) + r_sd = jnp.std(ref_latent_slice) + + # Calculate standard deviation and mean for the input latent slice + i_mean = jnp.mean(latent_slice) + i_sd = jnp.std(latent_slice) + i_sd_safe = jnp.where(i_sd < 1e-6, 1.0, i_sd) + normalized_latent = (latent_slice - i_mean) / i_sd_safe + transformed_latent_slice = normalized_latent * r_sd + r_mean + return transformed_latent_slice + + transformed_latents_core = jax.vmap( + jax.vmap(_adain_single_slice, in_axes=(0, 0)), + in_axes=(0, 0) # Vmap over batch (axis 0) + )(latents, reference_latents) + result_blended = latents * (1.0 - factor) + transformed_latents_core * factor + + return result_blended + + +class LTXMultiScalePipeline: + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=config.models_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler -class LTXMultiScalePipeline: ##figure these methods out - def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): + def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: jnp.ndarray): latents = jax.device_put(latents, jax.devices("tpu")[0]) - # assert latents.device == latest_upsampler.device with default_env(): - latents = un_normalize_latents( # need to switch this out? 
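A quick numeric check of the property the JAX adain_filter_latent above enforces: with factor=1.0, every (batch, channel) slice of the output takes on the mean and standard deviation of the corresponding reference slice. The check below re-implements the per-slice transform standalone (the pipeline version additionally pins arrays to a TPU device); shapes are illustrative.

import jax
import jax.numpy as jnp

def adain_slice(x, ref):
  # Normalize x to zero mean / unit std, then rescale to ref's statistics.
  x_norm = (x - jnp.mean(x)) / jnp.maximum(jnp.std(x), 1e-6)
  return x_norm * jnp.std(ref) + jnp.mean(ref)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
latents = jax.random.normal(key_a, (2, 4, 8, 16, 16)) * 3.0 + 1.0  # (B, C, F, H, W)
reference = jax.random.normal(key_b, (2, 4, 8, 16, 16))

# Nested vmap over batch then channel, mirroring the pipeline implementation.
filtered = jax.vmap(jax.vmap(adain_slice))(latents, reference)
print(jnp.mean(filtered[0, 0]), jnp.mean(reference[0, 0]))  # approximately equal
print(jnp.std(filtered[0, 0]), jnp.std(reference[0, 0]))    # approximately equal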
+ latents = un_normalize_latents( interop.torch_view(latents), self.vae, vae_per_channel_normalize=True ) upsampled_latents = latest_upsampler( @@ -979,29 +1029,21 @@ def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Te ) return upsampled_latents - def __init__(self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler): + def __init__(self, video_pipeline: LTXVideoPipeline): self.video_pipeline = video_pipeline self.vae = video_pipeline.vae - self.latent_upsampler = latent_upsampler - - def __call__(self, height, width, num_frames, is_video, output_type, generator, config) -> Any: + def __call__(self, height, width, num_frames, is_video, output_type, config, enhance_prompt: bool = False) -> Any: + #first pass original_output_type = output_type - original_width = width - original_height = height - x_width = int(width * config.downscale_factor) - downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) - x_height = int(height * config.downscale_factor) - downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) - # #use original height and width here to see output_type = "latent" result = self.video_pipeline( - height=original_height, - width=original_width, + height=height, + width=width, + enhance_prompt=enhance_prompt, num_frames=num_frames, - is_video=True, + is_video=is_video, output_type=output_type, - generator=generator, guidance_scale=config.first_pass["guidance_scale"], stg_scale=config.first_pass["stg_scale"], rescaling_scale=config.first_pass["rescaling_scale"], @@ -1015,25 +1057,23 @@ def __call__(self, height, width, num_frames, is_video, output_type, generator, ) latents = result print("done") - upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) # convert back to pytorch here - ##maybe change this? 
- latents = torch.from_numpy(np.array(latents)) # .to(torch.device('cpu')) - upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) # .to(torch.device('cpu')) + latent_upsampler = self.load_latent_upsampler(config) + upsampled_latents = self._upsample_latents(latent_upsampler, latents) + # upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) # .to(torch.device('cpu')) upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) + + #second pass latents = upsampled_latents output_type = original_output_type - width = downscaled_width * 2 - height = downscaled_height * 2 - # import pdb; pdb.set_trace() - latents = jnp.array(latents) + # latents = jnp.array(latents) result = self.video_pipeline( - height=original_height * 2, - width=original_width * 2, + height=height * 2, + width=width * 2, + enhance_prompt=enhance_prompt, num_frames=num_frames, - is_video=True, + is_video=is_video, output_type=original_output_type, latents=latents, - generator=generator, guidance_scale=config.second_pass["guidance_scale"], stg_scale=config.second_pass["stg_scale"], rescaling_scale=config.second_pass["rescaling_scale"], @@ -1052,7 +1092,7 @@ def __call__(self, height, width, num_frames, is_video, output_type, generator, videos = F.interpolate( videos, - size=(original_height, original_width), + size=(height, width), mode="bilinear", align_corners=False, ) From 072982c5a78ea4f294f7853152578b41c6259ba0 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 22 Jul 2025 02:42:17 +0000 Subject: [PATCH 57/69] added timing --- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 15 +++++++++++++++ .../pipelines/ltx_video/ltx_video_pipeline.py | 1 - 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 5798d1c8a..fced759d3 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -24,7 +24,7 @@ pipeline_type: multi-scale prompt: "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage." 
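The timing added in this commit follows the usual JAX measurement pattern: the first pipeline call pays tracing and XLA compilation, so a second, identical call is timed to get the steady-state generation cost. A generic sketch of that pattern, with a toy jitted function standing in for the pipeline:

import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_pipeline(x):
  # Stand-in for the jitted video pipeline.
  return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((1024, 1024))

s0 = time.perf_counter()
toy_pipeline(x).block_until_ready()  # first call: trace + compile + run
print("compile time:", time.perf_counter() - s0)

s0 = time.perf_counter()
toy_pipeline(x).block_until_ready()  # second call: reuses the compiled executable
print("generation time:", time.perf_counter() - s0)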
height: 512 width: 512 -num_frames: 344 #344 +num_frames: 88 #344 flow_shift: 5.0 fps: 24 downscale_factor: 0.6666666 diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 11d18bf77..cefa661ad 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -11,6 +11,7 @@ import os import torch +import time from pathlib import Path @@ -93,6 +94,7 @@ def run(config): pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) if config.pipeline_type == "multi-scale": pipeline = LTXMultiScalePipeline(pipeline) + s0 = time.perf_counter() images = pipeline( height=height_padded, width=width_padded, @@ -102,6 +104,19 @@ def run(config): config=config, enhance_prompt = False ) + print("compile time: ", (time.perf_counter() - s0)) + s0 = time.perf_counter() + images = pipeline( + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + is_video=True, + output_type="pt", + config=config, + enhance_prompt = False + ) + print("generation time: ", (time.perf_counter() - s0)) + (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom pad_right = -pad_right diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 3e45c8c73..96bc0702c 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -561,7 +561,6 @@ def __call__( skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, **kwargs, ): - import pdb; pdb.set_trace() prompt = self.config.prompt is_video = kwargs.get("is_video", False) if prompt is not None and isinstance(prompt, str): From 8042df0c17e313e4865a2da0681abff3845d9196 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 23 Jul 2025 01:38:48 +0000 Subject: [PATCH 58/69] pipeline cleaned, licence added --- src/maxdiffusion/configs/ltx_video.yml | 10 +- src/maxdiffusion/generate_ltx_video.py | 64 +++-- .../models/ltx_video/autoencoders/__init__.py | 16 ++ .../ltx_video/autoencoders/causal_conv3d.py | 16 ++ .../autoencoders/causal_video_autoencoder.py | 20 +- .../ltx_video/autoencoders/conv_nd_factory.py | 16 ++ .../ltx_video/autoencoders/dual_conv3d.py | 16 ++ .../autoencoders/latent_upsampler.py | 16 ++ .../ltx_video/autoencoders/pixel_norm.py | 16 ++ .../ltx_video/autoencoders/pixel_shuffle.py | 16 ++ .../models/ltx_video/autoencoders/vae.py | 16 ++ .../ltx_video/autoencoders/vae_encode.py | 16 ++ .../ltx_video/autoencoders/vae_torchax.py | 20 +- .../autoencoders/video_autoencoder.py | 18 +- .../models/ltx_video/repeatable_layer.py | 23 +- .../ltx_video/transformers/attention.py | 13 +- .../transformers/symmetric_patchifier.py | 16 ++ .../utils/diffusers_config_mapping.py | 16 ++ .../ltx_video/utils/prompt_enhance_utils.py | 16 ++ .../ltx_video/utils/skip_layer_strategy.py | 16 ++ .../models/ltx_video/utils/torch_utils.py | 16 ++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 1 - .../pipelines/ltx_video/ltx_video_pipeline.py | 263 +++++++----------- 23 files changed, 434 insertions(+), 222 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index fced759d3..f0fc85086 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -8,7 +8,7 @@ activations_dtype: 'bfloat16' run_name: '' -output_dir: 'ltx-video-output' +output_dir: '/mnt/disks/diffusionproj' save_config_to_gcs: 
False #Checkpoints @@ -21,19 +21,19 @@ sampler: "from_checkpoint" # Generation parameters pipeline_type: multi-scale -prompt: "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage." +prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. " +#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" height: 512 width: 512 -num_frames: 88 #344 +num_frames: 88 flow_shift: 5.0 -fps: 24 downscale_factor: 0.6666666 spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" prompt_enhancement_words_threshold: 120 stg_mode: "attention_values" decode_timestep: 0.05 decode_noise_scale: 0.025 -models_dir: "/mnt/disks/diffusionproj" #where safetensor file is +seed: 10 first_pass: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index cefa661ad..8423a7a8e 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,16 +1,28 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + import numpy as np from absl import app from typing import Sequence from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy -from huggingface_hub import hf_hub_download import imageio from datetime import datetime - import os -import torch import time from pathlib import Path @@ -28,9 +40,6 @@ def calculate_padding( pad_bottom = pad_height - pad_top # Handles odd padding pad_left = pad_width // 2 pad_right = pad_width - pad_left # Handles odd padding - - # Return padded tensor - # Padding format is (left, right, top, bottom) padding = (pad_left, pad_right, pad_top, pad_bottom) return padding @@ -59,8 +68,6 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: return "-".join(result) - - def get_unique_filename( base: str, ext: str, @@ -70,9 +77,7 @@ def get_unique_filename( endswith=None, index_range=1000, ) -> Path: - base_filename = ( - f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}" - ) + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}" for i in range(index_range): filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" if not os.path.exists(filename): @@ -87,13 +92,23 @@ def run(config): padding = calculate_padding(config.height, config.width, height_padded, width_padded) prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold prompt_word_count = len(config.prompt.split()) - enhance_prompt = ( - prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold - ) + enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) - if config.pipeline_type == "multi-scale": + if config.pipeline_type == "multi-scale": pipeline = LTXMultiScalePipeline(pipeline) + # s0 = time.perf_counter() + # images = pipeline( + # height=height_padded, + # width=width_padded, + # num_frames=num_frames_padded, + # is_video=True, + # output_type="pt", + # config=config, + # enhance_prompt=enhance_prompt, + # seed = config.seed + # ) + # print("compile time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() images = pipeline( height=height_padded, @@ -102,21 +117,11 @@ def run(config): is_video=True, output_type="pt", config=config, - enhance_prompt = False - ) - print("compile time: ", (time.perf_counter() - s0)) - s0 = time.perf_counter() - images = pipeline( - height=height_padded, - width=width_padded, - num_frames=num_frames_padded, - is_video=True, - output_type="pt", - config=config, - enhance_prompt = False + enhance_prompt=enhance_prompt, + seed=config.seed, ) print("generation time: ", (time.perf_counter() - s0)) - + (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom pad_right = -pad_right @@ -127,6 +132,7 @@ def run(config): images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right] output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C video_np = images[i].permute(1, 2, 3, 
0).detach().float().numpy() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py index e69de29bb..285b6e81c 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py index 8797880d3..7206893d0 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Tuple, Union import torch diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py index 439953ed6..dd94dfbb6 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import json import os from functools import partial @@ -218,11 +234,11 @@ def to_json_string(self) -> str: return json.dumps(self.config.__dict__) def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - if any([key.startswith("vae.") for key in state_dict.keys()]): + if any([key.startswith("vae.") for key in state_dict.keys()]): # noqa: C419 state_dict = {key.replace("vae.", ""): value for key, value in state_dict.items() if key.startswith("vae.")} ckpt_state_dict = {key: value for key, value in state_dict.items() if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)} - model_keys = set(name for name, _ in self.named_modules()) + model_keys = set(name for name, _ in self.named_modules()) # noqa: C401 key_mapping = { ".resnets.": ".res_blocks.", diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py index dd4aba0ca..d0be897e8 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Tuple, Union import torch diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py index 1487a539d..c0a4db3eb 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import math from typing import Tuple, Union diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py index 3133b8a49..56a6c2d1b 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Union from pathlib import Path import os diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py index 0d4277e06..422df50f5 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import torch from torch import nn diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py index dae539d26..7bd4761c4 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import torch.nn as nn from einops import rearrange diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py index 1a05b675c..17823c5d0 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Union import torch diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py index e24f16fd2..d4c44024e 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Tuple import torch from diffusers import AutoencoderKL diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py index dfb7512ff..ea46b4fbd 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_torchax.py @@ -1,12 +1,26 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder from maxdiffusion.models.ltx_video.autoencoders import causal_conv3d from maxdiffusion.models.ltx_video.autoencoders.vae_encode import vae_encode, vae_decode import jax from torchax import interop -import os from torchax import default_env -import jax.numpy as jnp # remove weight attribute to avoid error in JittableModule # in the future, this will be fixed in ltxv public repo @@ -16,7 +30,7 @@ class TorchaxCausalVideoAutoencoder(interop.JittableModule): def __init__(self, vae: CausalVideoAutoencoder): - super().__init__(vae, extra_jit_args=dict(static_argnames=["split_size", "vae_per_channel_normalize"])) + super().__init__(vae, extra_jit_args=dict(static_argnames=["split_size", "vae_per_channel_normalize"])) # noqa: C408 def encode(self, media_items: jax.Array, split_size: int = 1, vae_per_channel_normalize: bool = True) -> jax.Array: if media_items.ndim != 5: diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py index 77ae60d4e..8e6b67a43 100644 --- a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import json import os from functools import partial @@ -135,7 +151,7 @@ def to_json_string(self) -> str: return json.dumps(self.config.__dict__) def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - model_keys = set(name for name, _ in self.named_parameters()) + model_keys = set(name for name, _ in self.named_parameters()) # noqa: C401 key_mapping = { ".resnets.": ".res_blocks.", diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 31c6b5b15..8f6e43dc0 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional @@ -8,13 +24,6 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format - - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ - module: Callable[[Any], nn.Module] module_init_args: List[Any] module_init_kwargs: Dict[str, Any] diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index a6ac74c71..6fad32d8e 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -458,7 +458,6 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") @@ -476,9 +475,7 @@ def __call__( batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape( - skip_layer_mask[block_index], (batch_size, 1, 1) - ) # here skip_layer_mask is (48,3), changed this currently! + skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) query = self.to_q(hidden_states) query = self.q_norm(query) @@ -645,14 +642,6 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): if q_segment_ids is not None and q_segment_ids.ndim != 2: raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. - # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # ("data", "fsdp", "fsdp_transpose", "expert"), - # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - # None, - # None, - # ) - # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") qkvo_sharding_spec = jax.sharding.PartitionSpec( "data", "fsdp", diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py index 83ec62342..d53b4d7ca 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from abc import ABC, abstractmethod from typing import Tuple diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py index 832bda051..81094d676 100644 --- a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main def make_hashable_key(dict_key): def convert_value(value): if isinstance(value, list): diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py index 20d09f73e..1d404be39 100644 --- a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import logging from typing import Union, List, Optional diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py index 476d38c75..74e74c1c6 100644 --- a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py index 32b19b167..6dca31b1f 100644 --- a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main import torch from torch import nn diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index bce38fb20..75b16b011 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,5 +1,4 @@ { - "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 96bc0702c..0d9533dff 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -13,7 +13,6 @@ # limitations under the License. 
 import math
 import os
-from xmlrpc.client import Boolean
 from jax import Array
 from typing import Optional, List, Union, Tuple
 from einops import rearrange
@@ -28,7 +27,6 @@
 from transformers import (
     AutoModelForCausalLM,
     AutoProcessor,
-    AutoTokenizer,
 )
 from huggingface_hub import hf_hub_download
 from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import (
@@ -42,16 +40,12 @@
 )
 from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
 from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt
-from math import e
 from types import NoneType
 from typing import Any, Dict
-import numpy as np
 import jax
 import jax.numpy as jnp
-from jax.sharding import Mesh, PartitionSpec as P
-from typing import Optional, Union, List
-import torch
+from jax.sharding import Mesh
 from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier
 from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
@@ -59,22 +53,18 @@
 from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState
 from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations)
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
-import os
-import json
 import functools
 import orbax.checkpoint as ocp
-import pickle
 
 
-def validate_transformer_inputs(
-    prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
-):
-  print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
-  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
-  print("latents.shape: ", latents.shape, latents.dtype)
-  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
-  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
-  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
+def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids):
+  # Note: reference shapes are annotated for the first-pass default inference parameters
+  print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)  # (3, 256, 4096) float32
+  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)  # (3, 3, 3072) float32
+  print("latents.shape: ", latents.shape, latents.dtype)  # (1, 3072, 128) float32
+  print(
+      "encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype
+  )  # (3, 256) int32
 
 
 class LTXVideoPipeline:
@@ -137,7 +127,6 @@ def load_transformer(cls, config):
     config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json")
     with open(config_path, "r") as f:
       model_config = json.load(f)
-    relative_ckpt_path = model_config["ckpt_path"]
 
     ignored_keys = [
        "_class_name",
@@ -145,7 +134,6 @@ def load_transformer(cls, config):
        "_name_or_path",
        "causal_temporal_positioning",
        "in_channels",
-        "ckpt_path",
    ]
    in_channels = model_config["in_channels"]
    for name in ignored_keys:
@@ -154,12 +142,15 @@ def load_transformer(cls, config):
    transformer = Transformer3DModel(
        **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
    )
+    key = jax.random.PRNGKey(config.seed)
+    key, subkey = jax.random.split(key)
    weights_init_fn =
functools.partial( - transformer.init_weights, in_channels, jax.random.PRNGKey(0), model_config["caption_channels"], eval_only=True + transformer.init_weights, in_channels, subkey, model_config["caption_channels"], eval_only=True ) # loading from weight checkpoint - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + models_dir = config.output_dir + jax_weights_dir = os.path.join(models_dir, "jax_weights") + checkpoint_manager = ocp.CheckpointManager(jax_weights_dir) transformer_state, transformer_state_shardings = setup_initial_state( model=transformer, tx=None, @@ -224,7 +215,7 @@ def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) - models_dir = config.models_dir + models_dir = config.output_dir ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" if not os.path.isfile(ltxv_model_name_or_path): ltxv_model_path = hf_hub_download( @@ -343,7 +334,7 @@ def retrieve_timesteps( # currently doesn't support custom timesteps return scheduler_state - def encode_prompt( #input changed + def encode_prompt( # currently only supports passing in a prompt self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, @@ -356,14 +347,10 @@ def encode_prompt( #input changed batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] max_length = text_encoder_max_tokens # TPU supports only lengths multiple of 128 - - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," + + assert self.text_encoder is not None, "You should provide either prompt_embeds or self.text_encoder should not be None," prompt = self._text_preprocessing(prompt) text_inputs = self.tokenizer( prompt, @@ -374,10 +361,6 @@ def encode_prompt( #input changed return_tensors="pt", ) text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) prompt_attention_mask = jnp.array(text_inputs.attention_mask) prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) @@ -463,13 +446,12 @@ def prepare_latents( # currently no support for media item encoding, since enco latents = noise else: # Noise the latents to the required (first) timestep - # Ensure timestep is a jnp.array with the correct dtype timestep_array = jnp.array(timestep, dtype=dtype) latents = timestep_array * noise + (1 - timestep_array) * latents return latents - def prepare_conditioning( # no support for conditioning items, conditioning mask, needs to conver to torch before patchifier + def prepare_conditioning( # no support for conditioning items, conditioning mask, needs to convert to torch before patchifier self, init_latents: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: @@ -483,10 +465,9 @@ def prepare_conditioning( # no support for conditioning items, conditioning mas 0, ) - def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: r""" - Borrowed from diffusers.image_processor + Borrowed from diffusers.image_processor Denormalize an image array to [0,1]. 
 
        Args:
@@ -516,12 +497,11 @@ def _denormalize_conditionally(self, images: torch.Tensor, do_denormalize: Optio
 
    return torch.stack([self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])])
 
-  # same as diffusers image_processor.postprocess
-
-  def postprocess_to_output_type(self, image, output_type):
-    '''
+  def postprocess_to_output_type(self, image, output_type):
+    """
+    Borrowed from diffusers.image_processor
     Currently supporting output type latent and pt
-    '''
+    """
    if not isinstance(image, torch.Tensor):
      raise ValueError(f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor")
 
@@ -530,7 +510,7 @@ def postprocess_to_output_type(self, image, output_type):
    if output_type == "latent":
      return image
 
-    image = self._denormalize_conditionally(image, None)
+    image = self._denormalize_conditionally(image, None)
 
    if output_type == "pt":
      return image
@@ -557,18 +537,18 @@ def __call__(
      skip_initial_inference_steps: int = 0,
      skip_final_inference_steps: int = 0,
      cfg_star_rescale: bool = False,
+      seed: int = 0,
      skip_layer_strategy: Optional[SkipLayerStrategy] = None,
      skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
      **kwargs,
  ):
+    key = jax.random.PRNGKey(seed)
    prompt = self.config.prompt
    is_video = kwargs.get("is_video", False)
    if prompt is not None and isinstance(prompt, str):
      batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
      batch_size = len(prompt)
-    else:
-      batch_size = prompt_embeds.shape[0]
 
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
@@ -597,7 +577,7 @@ def __call__(
      skip_final_inference_steps,
    )
 
-    #set up guidance
+    # set up guidance
 
    guidance_mapping = []
    if guidance_timesteps:
@@ -631,8 +611,7 @@ def __call__(
    if do_spatio_temporal_guidance:
      num_conds += 1
 
-
-    #prepare skip block list
+    # prepare skip block list
    is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list)
    if not is_list_of_lists:
@@ -650,7 +629,7 @@ def __call__(
        self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks)
        for skip_blocks in skip_block_list
      ]
-
+
    if enhance_prompt:
      prompt = generate_cinematic_prompt(
        self.prompt_enhancer_image_caption_model,
@@ -676,7 +655,7 @@ def __call__(
      )
      prompt_embeds_batch = prompt_embeds
      prompt_attention_mask_batch = prompt_attention_mask
-
+
    if do_classifier_free_guidance:
      prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
      prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0)
@@ -689,14 +668,14 @@ def __call__(
        ],
        axis=0,
      )
-
-    #optionally pass in a latent here
+
+    # optionally pass in a latent here
    latents = self.prepare_latents(
        latents=latents,
        timestep=scheduler_state.timesteps[0],
        latent_shape=latent_shape,
        dtype=jnp.float32,
-        key=jax.random.PRNGKey(0),
+        key=key,
    )
 
    latents, pixel_coords, num_cond_latents = self.prepare_conditioning(
@@ -706,9 +685,7 @@ def __call__(
    pixel_coords = jnp.concatenate([pixel_coords] * num_conds, axis=0)
    fractional_coords = pixel_coords.astype(jnp.float32)
    fractional_coords = fractional_coords.at[:, 0].set(fractional_coords[:, 0] * (1.0 / frame_rate))
-
-    # initialize dummy noise
-    noise_cond = jnp.ones((1, 1))
+    validate_transformer_inputs(prompt_embeds_batch, fractional_coords, latents, prompt_attention_mask_batch)
 
    p_run_inference = functools.partial(
        run_inference,
@@ -737,7 +714,7 @@ def __call__(
    latents, scheduler_state =
p_run_inference( transformer_state=self.transformer_state, latents=latents, scheduler_state=scheduler_state ) - + latents = latents[:, num_cond_latents:] latents = self.patchifier.unpatchify( @@ -749,8 +726,7 @@ def __call__( if output_type != "latent": if self.vae.decoder.timestep_conditioning: - noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) - + noise = jax.random.normal(key, latents.shape, dtype=latents.dtype) # Convert decode_timestep to a list if it's not already one if not isinstance(decode_timestep, (list, jnp.ndarray)): decode_timestep = [decode_timestep] * latents.shape[0] @@ -761,7 +737,6 @@ def __call__( elif not isinstance(decode_noise_scale, (list, jnp.ndarray)): decode_noise_scale = [decode_noise_scale] * latents.shape[0] - # Convert lists to JAX arrays decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32) # Reshape decode_noise_scale for broadcasting @@ -778,7 +753,7 @@ def __call__( vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), timestep=decode_timestep, ) - # convert back to torch to post process using the diffusers library + # convert back to torch to postprocess using the diffusers library image = self.postprocess_to_output_type( torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type ) @@ -786,16 +761,13 @@ def __call__( else: image = latents - # Offload all models - if not return_dict: return (image,) return image - -def transformer_forward_pass( +def transformer_forward_pass( latents, state, noise_cond, @@ -818,7 +790,7 @@ def transformer_forward_pass( encoder_attention_segment_ids=encoder_attention_segment_ids, skip_layer_mask=skip_layer_mask, skip_layer_strategy=skip_layer_strategy, - ) + ) return noise_pred, state @@ -862,10 +834,10 @@ def run_inference( elif current_timestep.ndim == 0: current_timestep = jnp.expand_dims(current_timestep, axis=0) - # Broadcast to batch dimension (compatible with ONNX/Core ML) + # Broadcast to batch dimension current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): noise_pred, transformer_state = transformer_forward_pass( latent_model_input, transformer_state, @@ -878,8 +850,8 @@ def run_inference( skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), skip_layer_strategy=skip_layer_strategy, ) - - #perform guidance on noise prediction + + # perform guidance on noise prediction if do_spatio_temporal_guidance: chunks = jnp.split(noise_pred, num_conds, axis=0) noise_pred_text = chunks[-2] @@ -892,11 +864,11 @@ def run_inference( if cfg_star_rescale: positive_flat = noise_pred_text.reshape(batch_size, -1) negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) - squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 - alpha = dot_product / squared_norm + dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm alpha = alpha.reshape(batch_size, 1, 1) - noise_pred_uncond = alpha * noise_pred_uncond + noise_pred_uncond = alpha * noise_pred_uncond noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) elif do_spatio_temporal_guidance: noise_pred = noise_pred_text @@ -911,87 +883,58 @@ def 
run_inference( factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] + current_timestep = current_timestep[:1] latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() return latents, scheduler_state +def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, factor: float = 1.0) -> jnp.ndarray: + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor, implemented in JAX. + + Args: + latents (jax.Array): Input latents to normalize. Expected shape (B, C, F, H, W). + reference_latents (jax.Array): The reference latents providing style statistics. + Expected shape (B, C, F, H, W). + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + jax.Array: The transformed latent tensor. + """ + with default_env(): + latents = jax.device_put(latents, jax.devices("tpu")[0]) + reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) + # Define the core AdaIN operation for a single (F, H, W) slice. + # This function will be vmapped over batch (B) and channel (C) dimensions. + def _adain_single_slice(latent_slice: jnp.ndarray, ref_latent_slice: jnp.ndarray) -> jnp.ndarray: + """ + Applies AdaIN to a single latent slice (F, H, W) based on a reference slice. + """ + r_mean = jnp.mean(ref_latent_slice) + r_sd = jnp.std(ref_latent_slice) -#up to here -# def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): -# """ -# Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on -# statistics from a reference latent tensor. - -# Args: -# latent (torch.Tensor): Input latents to normalize -# reference_latent (torch.Tensor): The reference latents providing style statistics. -# factor (float): Blending factor between original and transformed latent. -# Range: -10.0 to 10.0, Default: 1.0 - -# Returns: -# torch.Tensor: The transformed latent tensor -# """ -# with default_env(): -# result = latents.clone() - -# for i in range(latents.size(0)): -# for c in range(latents.size(1)): -# r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order -# i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + # Calculate standard deviation and mean for the input latent slice + i_mean = jnp.mean(latent_slice) + i_sd = jnp.std(latent_slice) + i_sd_safe = jnp.where(i_sd < 1e-6, 1.0, i_sd) + normalized_latent = (latent_slice - i_mean) / i_sd_safe + transformed_latent_slice = normalized_latent * r_sd + r_mean + return transformed_latent_slice -# result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + transformed_latents_core = jax.vmap( + jax.vmap(_adain_single_slice, in_axes=(0, 0)), in_axes=(0, 0) # Vmap over batch (axis 0) + )(latents, reference_latents) + result_blended = latents * (1.0 - factor) + transformed_latents_core * factor -# result = torch.lerp(latents, result, factor) -# return result + return result_blended -def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, factor: float = 1.0) -> jnp.ndarray: - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor, implemented in JAX. - Args: - latents (jax.Array): Input latents to normalize. Expected shape (B, C, F, H, W). 
- reference_latents (jax.Array): The reference latents providing style statistics. - Expected shape (B, C, F, H, W). - factor (float): Blending factor between original and transformed latent. - Range: -10.0 to 10.0, Default: 1.0 +class LTXMultiScalePipeline: - Returns: - jax.Array: The transformed latent tensor. - """ - with default_env(): - latents = jax.device_put(latents, jax.devices("tpu")[0]) - reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) - # Define the core AdaIN operation for a single (F, H, W) slice. - # This function will be vmapped over batch (B) and channel (C) dimensions. - def _adain_single_slice(latent_slice: jnp.ndarray, ref_latent_slice: jnp.ndarray) -> jnp.ndarray: - """ - Applies AdaIN to a single latent slice (F, H, W) based on a reference slice. - """ - r_mean = jnp.mean(ref_latent_slice) - r_sd = jnp.std(ref_latent_slice) - - # Calculate standard deviation and mean for the input latent slice - i_mean = jnp.mean(latent_slice) - i_sd = jnp.std(latent_slice) - i_sd_safe = jnp.where(i_sd < 1e-6, 1.0, i_sd) - normalized_latent = (latent_slice - i_mean) / i_sd_safe - transformed_latent_slice = normalized_latent * r_sd + r_mean - return transformed_latent_slice - - transformed_latents_core = jax.vmap( - jax.vmap(_adain_single_slice, in_axes=(0, 0)), - in_axes=(0, 0) # Vmap over batch (axis 0) - )(latents, reference_latents) - result_blended = latents * (1.0 - factor) + transformed_latents_core * factor - - return result_blended - - -class LTXMultiScalePipeline: @classmethod def load_latent_upsampler(cls, config): spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path @@ -1000,7 +943,7 @@ def load_latent_upsampler(cls, config): spatial_upscaler_model_path = hf_hub_download( repo_id="Lightricks/LTX-Video", filename=spatial_upscaler_model_name_or_path, - local_dir=config.models_dir, + local_dir=config.output_dir, repo_type="model", ) else: @@ -1013,16 +956,11 @@ def load_latent_upsampler(cls, config): latent_upsampler.eval() return latent_upsampler - def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: jnp.ndarray): latents = jax.device_put(latents, jax.devices("tpu")[0]) with default_env(): - latents = un_normalize_latents( - interop.torch_view(latents), self.vae, vae_per_channel_normalize=True - ) - upsampled_latents = latest_upsampler( - torch.from_numpy(np.array(latents)) - ) # here converted back to torch, cause upsampler in pytorch + latents = un_normalize_latents(interop.torch_view(latents), self.vae, vae_per_channel_normalize=True) + upsampled_latents = latest_upsampler(torch.from_numpy(np.array(latents))) # converted back to torch before upsampler upsampled_latents = normalize_latents( interop.torch_view(jnp.array(upsampled_latents.detach().numpy())), self.vae, vae_per_channel_normalize=True ) @@ -1032,8 +970,10 @@ def __init__(self, video_pipeline: LTXVideoPipeline): self.video_pipeline = video_pipeline self.vae = video_pipeline.vae - def __call__(self, height, width, num_frames, is_video, output_type, config, enhance_prompt: bool = False) -> Any: - #first pass + def __call__( + self, height, width, num_frames, is_video, output_type, config, seed: int = 0, enhance_prompt: bool = False + ) -> Any: + # first pass original_output_type = output_type output_type = "latent" result = self.video_pipeline( @@ -1043,6 +983,7 @@ def __call__(self, height, width, num_frames, is_video, output_type, config, enh num_frames=num_frames, is_video=is_video, output_type=output_type, + seed=seed, 
guidance_scale=config.first_pass["guidance_scale"], stg_scale=config.first_pass["stg_scale"], rescaling_scale=config.first_pass["rescaling_scale"], @@ -1055,22 +996,20 @@ def __call__(self, height, width, num_frames, is_video, output_type, config, enh skip_block_list=config.first_pass["skip_block_list"], ) latents = result - print("done") + print("first pass done") latent_upsampler = self.load_latent_upsampler(config) upsampled_latents = self._upsample_latents(latent_upsampler, latents) - # upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) # .to(torch.device('cpu')) upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) - - #second pass + + # second pass latents = upsampled_latents - output_type = original_output_type - # latents = jnp.array(latents) result = self.video_pipeline( height=height * 2, width=width * 2, enhance_prompt=enhance_prompt, num_frames=num_frames, is_video=is_video, + seed=seed, output_type=original_output_type, latents=latents, guidance_scale=config.second_pass["guidance_scale"], From 0c485245fde6326aabc9526d18a56c8cb6f5c366 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 23 Jul 2025 01:52:55 +0000 Subject: [PATCH 59/69] changed output to cmd line --- src/maxdiffusion/configs/ltx_video.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index f0fc85086..5ff699bad 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -8,7 +8,7 @@ activations_dtype: 'bfloat16' run_name: '' -output_dir: '/mnt/disks/diffusionproj' +output_dir: '' save_config_to_gcs: False #Checkpoints From d4c6738f91b1cbe3500b1e8b151bc16dfb89924f Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 23 Jul 2025 17:23:36 +0000 Subject: [PATCH 60/69] added init file --- src/maxdiffusion/pipelines/ltx_video/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/maxdiffusion/pipelines/ltx_video/__init__.py b/src/maxdiffusion/pipelines/ltx_video/__init__.py index e69de29bb..0a2669d7a 100644 --- a/src/maxdiffusion/pipelines/ltx_video/__init__.py +++ b/src/maxdiffusion/pipelines/ltx_video/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
From 36242d272dedc27b860f0a1493040bb7bb3e7d45 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 23 Jul 2025 18:48:29 +0000 Subject: [PATCH 61/69] changed input format --- .../utils/conversion_script_instruction.md | 11 +----- .../utils/convert_torch_weights_to_jax.py | 39 ++++++------------- 2 files changed, 14 insertions(+), 36 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md index a6ca08835..0f316889a 100644 --- a/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md +++ b/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md @@ -1,10 +1,3 @@ ### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions: -1. Weight Downloading and Conversion - - If first time running (no local safetensors): \ - In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json. - - If already have local pytorch checkpoint: \ - Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location -2. Restoring Jax Weights into transformer: - - Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path. - - Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder. - +In the folder src/maxdiffusion/models/ltx_video/utils, run: +python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index 82ff03bab..151f6a3b9 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -230,18 +230,16 @@ def main(args): weight_file = "ltxv-13b-0.9.7-dev.safetensors" # download from huggingface, otherwise load from local - if args.local_ckpt_path is None: - print("Loading from HF", flush=True) - model_name = "Lightricks/LTX-Video" - local_file_path = hf_hub_download( - repo_id=model_name, - filename=weight_file, - local_dir=args.download_ckpt_path, - local_dir_use_symlinks=False, - ) - else: - base_dir = args.local_ckpt_path - local_file_path = os.path.join(base_dir, weight_file) + + print("Loading from HF", flush=True) + model_name = "Lightricks/LTX-Video" + absolute_ckpt_path = os.path.abspath(args.ckpt_path) + local_file_path = hf_hub_download( + repo_id=model_name, + filename=weight_file, + local_dir=absolute_ckpt_path, + local_dir_use_symlinks=False, + ) torch_state_dict = load_file(local_file_path) print("Initializing pytorch transformer..", flush=True) @@ -284,7 +282,7 @@ def main(args): params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) print("Creating checkpointer and jax state for saving..", flush=True) - relative_ckpt_path = args.output_dir + relative_ckpt_path = os.path.join(args.ckpt_path, "jax_weights") absolute_ckpt_path = os.path.abspath(relative_ckpt_path) tx = optax.adamw(learning_rate=1e-5) with jax.default_device("cpu"): @@ -303,25 +301,12 @@ def main(args): if __name__ == "__main__": parser = 
argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") parser.add_argument( - "--local_ckpt_path", + "--ckpt_path", type=str, required=False, help="Local path of the checkpoint to convert. If not provided, will download from huggingface for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'", ) - parser.add_argument( - "--download_ckpt_path", - type=str, - required=False, - help="Location to download safetensors from huggingface", - ) - - parser.add_argument( - "--output_dir", - type=str, - required=True, - help="Path to save the checkpoint to. for example 'gs://lt-research-mm-europe-west4/jax_trainings/converted-from-torch'", - ) parser.add_argument( "--output_step_num", default=1, From 774e2c4ae06f805ded7a9c2b8293d52e1cb80af8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 24 Jul 2025 21:49:11 +0000 Subject: [PATCH 62/69] updated requirements --- requirements.txt | 4 +++- setup.sh | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index eeaf2c9e3..eeb010951 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cpu -jax==0.5.3 +jax==0.6.2 jaxlib>=0.4.30 grain-nightly==0.0.10 google-cloud-storage==2.17.0 @@ -24,6 +24,8 @@ tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 git+https://github.com/mlperf/logging.git +git+https://github.com/Lightricks/LTX-Video +git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax opencv-python-headless==4.10.0.84 orbax-checkpoint==0.10.3 tokenizers==0.21.0 diff --git a/setup.sh b/setup.sh index fc3f640e0..9beb11d23 100644 --- a/setup.sh +++ b/setup.sh @@ -110,4 +110,4 @@ else fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 +pip3 install -e . || echo "Failed to install maxdiffusion" >&2 From f23eeef85dea3cbb3cbeef39c76a5a47f9c6b944 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 24 Jul 2025 22:21:31 +0000 Subject: [PATCH 63/69] merged in conversion script --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 9beb11d23..fc3f640e0 100644 --- a/setup.sh +++ b/setup.sh @@ -110,4 +110,4 @@ else fi # Install maxdiffusion -pip3 install -e . || echo "Failed to install maxdiffusion" >&2 +pip3 install -U . 
|| echo "Failed to install maxdiffusion" >&2 From c18c0c66d1b27dd40391d97beece7d674d441fb1 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 25 Jul 2025 08:08:38 +0000 Subject: [PATCH 64/69] fixed importing error --- setup.sh | 2 +- src/maxdiffusion/configs/ltx_video.yml | 1 + src/maxdiffusion/generate_ltx_video.py | 12 - .../transformers_pytorch/__init__.py | 0 .../transformers_pytorch/attention.py | 1264 +++++++++++++++++ .../transformers_pytorch/embeddings.py | 129 ++ .../symmetric_patchifier.py | 84 ++ .../transformers_pytorch/transformer3d.py | 507 +++++++ .../utils/convert_torch_weights_to_jax.py | 17 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 12 +- 10 files changed, 2000 insertions(+), 28 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py diff --git a/setup.sh b/setup.sh index fc3f640e0..e91d44156 100644 --- a/setup.sh +++ b/setup.sh @@ -110,4 +110,4 @@ else fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 +pip3 install -U . || echo "Failed to install maxdiffusion" >&2 \ No newline at end of file diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 5ff699bad..07e101a4d 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -9,6 +9,7 @@ activations_dtype: 'bfloat16' run_name: '' output_dir: '' +config_path: '' save_config_to_gcs: False #Checkpoints diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 8423a7a8e..cf1154081 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -97,18 +97,6 @@ def run(config): pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) if config.pipeline_type == "multi-scale": pipeline = LTXMultiScalePipeline(pipeline) - # s0 = time.perf_counter() - # images = pipeline( - # height=height_padded, - # width=width_padded, - # num_frames=num_frames_padded, - # is_video=True, - # output_type="pt", - # config=config, - # enhance_prompt=enhance_prompt, - # seed = config.seed - # ) - # print("compile time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() images = pipeline( height=height_padded, diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py new file mode 100644 index 000000000..bee0839ad --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -0,0 +1,1264 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, 
+ LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +try: + from torch_xla.experimental.custom_kernel import flash_attention +except ImportError: + # workaround for automatic tests. Currently this function is manually patched + # to the torch_xla lib on setup of container + pass + +# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = ( + nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = make_norm_layer( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer( + dim, norm_eps, norm_elementwise_affine + ) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. + if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter( + torch.randn(num_ada_params, dim) / dim**0.5 + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. 
+ """ + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + original_hidden_states = hidden_states + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ada_values.unbind(dim=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze( + 1 + ) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.TransformerBlock + ): + skip_layer_mask = skip_layer_mask.view(-1, 1, 1) + hidden_states = hidden_states * skip_layer_mask + original_hidden_states * ( + 1.0 - skip_layer_mask + ) + + return hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. 
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True + ) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + + self.processor = processor + + def get_processor( + self, return_deprecated_lora: bool = False + ) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. 
+ + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr( + import_module(__name__), "LoRA" + non_lora_processor_cls_name + ) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict( + self.add_k_proj.lora_layer.state_dict() + ) + lora_processor.add_v_proj_lora.load_state_dict( + self.add_v_proj.lora_layer.state_dict() + ) + else: + 
lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + + return self.processor( + self, + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. 
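+ For example (illustration only): with `heads=8`, a `[2, 77, 512]` tensor becomes `[16, 77, 64]`
+ when `out_dim=3`, or `[2, 8, 77, 64]` when `out_dim=4`.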
+ """ + + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
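+ # Added note: e.g. a 3-D mask of shape (batch, q_len, current_length) gains `target_length`
+ # zero columns along its last axis here, mirroring the F.pad branch below.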
+ padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor + ) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if ( + attention_mask is not None + ): # if mask is required need to tune both segmenIds fields + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones( + batch_size, query.shape[2], device=query.device, dtype=torch.float32 + ) + assert ( + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert ( + query.shape[2] % 128 == 0 + ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert ( + key.shape[2] % 128 == 0 + ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + 
hidden_states_a = hidden_states_a.to(query.dtype) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionSkip + ): + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( + 1.0 - skip_layer_mask + ) + elif ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionValues + ): + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
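+
+ Example (illustrative sketch, assuming the default GEGLU activation):
+ ff = FeedForward(dim=1024, mult=4) # 1024 -> 4096 (gated) -> 1024
+ y = ff(torch.randn(2, 77, 1024)) # y.shape == (2, 77, 1024)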
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py new file mode 100644 index 000000000..a30d6be16 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -0,0 +1,129 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
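+ For example (illustration only): `get_timestep_embedding(torch.tensor([0.0, 500.0]), 320)` returns a
+ `[2, 320]` tensor whose first half along the last dimension holds sines and the second half cosines
+ (the halves are swapped when `flip_sin_to_cos=True`).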
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. 
+ max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) + ) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py new file mode 100644 index 000000000..2eca32033 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
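+ For example (illustration only): with `patch_size=1`, `latent_num_frames=2`, `latent_height=3`,
+ `latent_width=4` and `batch_size=2`, the result has shape `[2, 3, 24]`, with the second axis holding
+ (frame, height, width) coordinates.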
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py new file mode 100644 index 000000000..8c1d1991d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py @@ -0,0 +1,507 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from ltx_video.models.transformers.attention import BasicTransformerBlock +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. 
+ """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = ( + use_tpu_flash_attention # FIXME: push config down to the attention modules + ) + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" + ) + if positional_embedding_max_pos is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=False + ) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. 
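+ # Added explanation: positions are first normalised to fractions of `positional_embedding_max_pos`;
+ # `dim // 6` frequencies are then generated per spatial axis (3 axes, each later expanded into an
+ # interleaved cos/sin pair), so the returned tensors span the transformer's full inner dimension.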
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // 6, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): #noqa: C419 + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = ( + pretrained_model_path + / "transformer" + / "diffusion_pytorch_model*.safetensors" + ) + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + with torch.device("meta"): + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, assign=True, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + with torch.device("meta"): + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict, assign=True) + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. 
+ encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. Output + scale_shift_values = ( + self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index 151f6a3b9..45f1d3b0b 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -29,6 +29,7 @@ from urllib.parse import urljoin from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d import Transformer3DModel as Transformer3DModel from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax from huggingface_hub import hf_hub_download @@ -217,14 +218,14 @@ def main(args): "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " "training loss when resuming from the converted checkpoint." 
) - print("Downloading files from GitHub...") - github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" - ltx_repo_path = "../" - target_folder = "transformers_pytorch" - files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] - module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" - - Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) + # print("Downloading files from GitHub...") + # github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" + # ltx_repo_path = "../" + # target_folder = "transformers_pytorch" + # files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] + # module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" + + # Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) print("Loading safetensors, flush = True") weight_file = "ltxv-13b-0.9.7-dev.safetensors" diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 0d9533dff..e5f417bed 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -123,9 +123,7 @@ def load_scheduler(cls, ckpt_path, config): def load_transformer(cls, config): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: + with open(config.config_path, "r") as f: model_config = json.load(f) ignored_keys = [ @@ -516,6 +514,7 @@ def postprocess_to_output_type(self, image, output_type): def __call__( self, + config, height: int, width: int, num_frames: int, @@ -555,11 +554,8 @@ def __call__( latent_num_frames = num_frames // self.video_scale_factor if isinstance(self.vae, TorchaxCausalVideoAutoencoder) and is_video: latent_num_frames += 1 - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: + with open(config.config_path, "r") as f: model_config = json.load(f) - latent_shape = ( batch_size * num_images_per_prompt, model_config["in_channels"], @@ -977,6 +973,7 @@ def __call__( original_output_type = output_type output_type = "latent" result = self.video_pipeline( + config=config, height=height, width=width, enhance_prompt=enhance_prompt, @@ -1004,6 +1001,7 @@ def __call__( # second pass latents = upsampled_latents result = self.video_pipeline( + config=config, height=height * 2, width=width * 2, enhance_prompt=enhance_prompt, From cfe2c642cfe09aee474d7f1fb1ce1cc4620f22e3 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 25 Jul 2025 17:42:51 +0000 Subject: [PATCH 65/69] fixed importing issue --- .../ltx_video/transformers_pytorch/__init__.py | 16 ++++++++++++++++ .../ltx_video/transformers_pytorch/embeddings.py | 16 ++++++++++++++++ .../transformers_pytorch/symmetric_patchifier.py | 16 ++++++++++++++++ .../transformers_pytorch/transformer3d.py | 16 ++++++++++++++++ .../utils/convert_torch_weights_to_jax.py | 8 -------- 5 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py 
b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py index e69de29bb..285b6e81c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py index a30d6be16..6461039fb 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py import math diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py index 2eca32033..b34df6ed3 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from abc import ABC, abstractmethod from typing import Tuple diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py index 8c1d1991d..74366830c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py import math from dataclasses import dataclass diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index 45f1d3b0b..84b416d52 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -218,14 +218,6 @@ def main(args): "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " "training loss when resuming from the converted checkpoint." 
) - # print("Downloading files from GitHub...") - # github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" - # ltx_repo_path = "../" - # target_folder = "transformers_pytorch" - # files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] - # module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" - - # Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) print("Loading safetensors, flush = True") weight_file = "ltxv-13b-0.9.7-dev.safetensors" From 740d4035596e7718cf1e6acdb27d23dfb03e741b Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 25 Jul 2025 18:23:29 +0000 Subject: [PATCH 66/69] merged from main --- src/maxdiffusion/configs/ltx_video.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 07e101a4d..fce674f2c 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -83,10 +83,9 @@ ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 - - +allow_split_physical_axes: False learning_rate_schedule_steps: -1 -max_train_steps: 500 #TODO: change this +max_train_steps: 500 pretrained_model_name_or_path: '' unet_checkpoint: '' dataset_name: 'diffusers/pokemon-gpt4-captions' From e34d47e5b56bf5e1cf87dbba0e9f0cab7d35b2f1 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:24:55 -0700 Subject: [PATCH 67/69] Delete myenv directory --- myenv/bin/python | 1 - myenv/bin/python3 | 1 - myenv/bin/python3.10 | 1 - myenv/lib64 | 1 - myenv/pyvenv.cfg | 3 --- 5 files changed, 7 deletions(-) delete mode 120000 myenv/bin/python delete mode 120000 myenv/bin/python3 delete mode 120000 myenv/bin/python3.10 delete mode 120000 myenv/lib64 delete mode 100644 myenv/pyvenv.cfg diff --git a/myenv/bin/python b/myenv/bin/python deleted file mode 120000 index acd4152a9..000000000 --- a/myenv/bin/python +++ /dev/null @@ -1 +0,0 @@ -/usr/bin/python \ No newline at end of file diff --git a/myenv/bin/python3 b/myenv/bin/python3 deleted file mode 120000 index d8654aa0e..000000000 --- a/myenv/bin/python3 +++ /dev/null @@ -1 +0,0 @@ -python \ No newline at end of file diff --git a/myenv/bin/python3.10 b/myenv/bin/python3.10 deleted file mode 120000 index d8654aa0e..000000000 --- a/myenv/bin/python3.10 +++ /dev/null @@ -1 +0,0 @@ -python \ No newline at end of file diff --git a/myenv/lib64 b/myenv/lib64 deleted file mode 120000 index 7951405f8..000000000 --- a/myenv/lib64 +++ /dev/null @@ -1 +0,0 @@ -lib \ No newline at end of file diff --git a/myenv/pyvenv.cfg b/myenv/pyvenv.cfg deleted file mode 100644 index 266b9070b..000000000 --- a/myenv/pyvenv.cfg +++ /dev/null @@ -1,3 +0,0 @@ -home = /usr/bin -include-system-site-packages = false -version = 3.10.6 From 460f7dba76fd840655482f7bd51dd033ed511285 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 25 Jul 2025 21:21:20 +0000 Subject: [PATCH 68/69] changed ckpt name --- requirements.txt | 1 - .../checkpointing/checkpointing_utils.py | 2 +- src/maxdiffusion/generate_ltx_video.py | 5 +- src/maxdiffusion/max_utils.py | 2 +- src/maxdiffusion/models/ltx_video/linear.py | 2 - .../models/ltx_video/transformers/adaln.py | 1 - .../transformers_pytorch/attention.py | 2233 ++++++++--------- .../transformers_pytorch/embeddings.py | 
168 +- .../symmetric_patchifier.py | 130 +- .../transformers_pytorch/transformer3d.py | 873 +++---- .../utils/convert_torch_weights_to_jax.py | 65 - .../pipelines/ltx_video/ltx_video_pipeline.py | 13 +- .../schedulers/scheduling_rectified_flow.py | 3 + 13 files changed, 1624 insertions(+), 1874 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7be2f4aab..2ccbd88ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 -git+https://github.com/mlperf/logging.git git+https://github.com/Lightricks/LTX-Video git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax opencv-python-headless==4.10.0.84 diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index b83e85a87..1a8d12ef0 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,7 +213,7 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - if checkpoint_item == " ": + if checkpoint_item == "ltxvid_transformer": return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) else: item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index cf1154081..553d6373e 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -19,7 +19,7 @@ from typing import Sequence from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline -from maxdiffusion import pyconfig +from maxdiffusion import pyconfig, max_logging import imageio from datetime import datetime import os @@ -108,7 +108,7 @@ def run(config): enhance_prompt=enhance_prompt, seed=config.seed, ) - print("generation time: ", (time.perf_counter() - s0)) + max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom @@ -146,7 +146,6 @@ def run(config): resolution=(height, width, config.num_frames), dir=output_dir, ) - print(output_filename) # Write video with imageio.get_writer(output_filename, fps=fps) as video: for frame in video_np: diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 0abeee938..96b60426d 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -405,7 +405,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - if checkpoint_item == " ": + if checkpoint_item == "ltxvid_transformer": state = state else: state = state[checkpoint_item] diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index 3503ab3b4..247e9da1f 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -95,8 +95,6 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - # kernel_in_axis = np.arange(len(axis)) - # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) kernel = self.param( "kernel", 
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index e9b287649..1078f0848 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -126,7 +126,6 @@ def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - """ """ embedding_dim: int size_emb_dim: int diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index bee0839ad..a598114ad 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -23,11 +23,11 @@ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy try: - from torch_xla.experimental.custom_kernel import flash_attention + from torch_xla.experimental.custom_kernel import flash_attention except ImportError: - # workaround for automatic tests. Currently this function is manually patched - # to the torch_xla lib on setup of container - pass + # workaround for automatic tests. Currently this function is manually patched + # to the torch_xla lib on setup of container + pass # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py @@ -36,1229 +36,1102 @@ @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. 
+ attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = make_norm_layer(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=(cross_attention_dim if not double_self_attention else None), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. 
+ if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter(torch.randn(num_ada_params, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - qk_norm (`str`, *optional*, defaults to None): - Set to 'layer_norm' or `rms_norm` to perform query and key normalization. - adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): - The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". - standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): - The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. 
""" + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + original_hidden_states = hidden_states + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze(1) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.TransformerBlock: + skip_layer_mask = skip_layer_mask.view(-1, 1, 1) + hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (1.0 - skip_layer_mask) + + return hidden_states - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' - standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' - norm_eps: float = 1e-5, - qk_norm: Optional[str] = None, - final_dropout: bool = False, - attention_type: str = "default", # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - use_tpu_flash_attention: bool = False, - use_rope: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - self.use_tpu_flash_attention = use_tpu_flash_attention - self.adaptive_norm = adaptive_norm - assert standardization_norm in ["layer_norm", "rms_norm"] - assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. 
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError(f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'") + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True - make_norm_layer = ( - nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm - ) + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.norm1 = make_norm_layer( - dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps - ) + Args: + processor (`AttnProcessor`): + The attention processor to use. 
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - use_tpu_flash_attention=use_tpu_flash_attention, - qk_norm=qk_norm, - use_rope=use_rope, - ) + self.processor = processor - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=( - cross_attention_dim if not double_self_attention else None - ), - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - use_tpu_flash_attention=use_tpu_flash_attention, - qk_norm=qk_norm, - use_rope=use_rope, - ) # is self-attn if encoder_hidden_states is none - - if adaptive_norm == "none": - self.attn2_norm = make_norm_layer( - dim, norm_eps, norm_elementwise_affine - ) - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) - - # 3. Feed-forward - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. - # 5. Scale-shift for PixArt-Alpha. - if adaptive_norm != "none": - num_ada_params = 4 if adaptive_norm == "single_scale" else 6 - self.scale_shift_table = nn.Parameter( - torch.randn(num_ada_params, dim) / dim**0.5 - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_use_tpu_flash_attention(self): - r""" - Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU - attention kernel. - """ - self.use_tpu_flash_attention = True - self.attn1.set_use_tpu_flash_attention() - self.attn2.set_use_tpu_flash_attention() - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None for name, module in self.named_modules() if hasattr(module, "lora_layer") + } + + # 1. 
if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError(f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}") + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. 
+ + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( self, - hidden_states: torch.FloatTensor, - freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - skip_layer_mask: Optional[torch.Tensor] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> torch.FloatTensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." - ) - - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - original_hidden_states = hidden_states - - norm_hidden_states = self.norm1(hidden_states) - - # Apply ada_norm_single - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None] + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - ada_values.unbind(dim=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - norm_hidden_states = norm_hidden_states.squeeze( - 1 - ) # TODO: Check if this is needed - - # 1. 
Prepare GLIGEN inputs - cross_attention_kwargs = ( - cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - ) + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=( - encoder_hidden_states if self.only_cross_attention else None - ), - attention_mask=attention_mask, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **cross_attention_kwargs, - ) - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 3. Cross-Attention - if self.attn2 is not None: - if self.adaptive_norm == "none": - attn_input = self.attn2_norm(hidden_states) - else: - attn_input = hidden_states - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - norm_hidden_states = self.norm2(hidden_states) - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward( - self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size - ) - else: - ff_output = self.ff(norm_hidden_states) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - if ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.TransformerBlock - ): - skip_layer_mask = skip_layer_mask.view(-1, 1, 1) - hidden_states = hidden_states * skip_layer_mask + original_hidden_states * ( - 1.0 - skip_layer_mask - ) - - return hidden_states + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + Args: + tensor (`torch.Tensor`): The tensor to reshape. -@maybe_allow_in_graph -class Attention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): - The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): - The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
- upcast_attention (`bool`, *optional*, defaults to False): - Set to `True` to upcast the attention computation to `float32`. - upcast_softmax (`bool`, *optional*, defaults to False): - Set to `True` to upcast the softmax computation to `float32`. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the group norm in the cross attention. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - norm_num_groups (`int`, *optional*, defaults to `None`): - The number of groups to use for the group norm in the attention. - spatial_norm_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the spatial normalization. - out_bias (`bool`, *optional*, defaults to `True`): - Set to `True` to use a bias in the output linear layer. - scale_qk (`bool`, *optional*, defaults to `True`): - Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. - qk_norm (`str`, *optional*, defaults to None): - Set to 'layer_norm' or `rms_norm` to perform query and key normalization. - only_cross_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if - `added_kv_proj_dim` is not `None`. - eps (`float`, *optional*, defaults to 1e-5): - An additional value added to the denominator in group normalization that is used for numerical stability. - rescale_output_factor (`float`, *optional*, defaults to 1.0): - A factor to rescale the output by dividing it with this value. - residual_connection (`bool`, *optional*, defaults to `False`): - Set to `True` to add the residual connection to the output. - _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): - Set to `True` if the attention block is loaded from a deprecated state dict. - processor (`AttnProcessor`, *optional*, defaults to `None`): - The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and - `AttnProcessor` otherwise. + Returns: + `torch.Tensor`: The reshaped tensor. 
""" + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - qk_norm: Optional[str] = None, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - use_tpu_flash_attention: bool = False, - use_rope: bool = False, - ): - super().__init__() - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = ( - cross_attention_dim if cross_attention_dim is not None else query_dim - ) - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.use_tpu_flash_attention = use_tpu_flash_attention - self.use_rope = use_rope - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - if qk_norm is None: - self.q_norm = nn.Identity() - self.k_norm = nn.Identity() - elif qk_norm == "rms_norm": - self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) - self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) - elif qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) - self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) - else: - raise ValueError(f"Unsupported qk_norm method: {qk_norm}") - - self.heads = out_dim // dim_head if out_dim is not None else heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
- ) - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm( - num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True - ) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm( - f_channels=query_dim, zq_channels=spatial_norm_dim - ) - else: - self.spatial_norm = None - - if cross_attention_norm is None: - self.norm_cross = None - elif cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(self.cross_attention_dim) - elif cross_attention_norm == "group_norm": - if self.added_kv_proj_dim is not None: - # The given `encoder_hidden_states` are initially of shape - # (batch_size, seq_len, added_kv_proj_dim) before being projected - # to (batch_size, seq_len, cross_attention_dim). The norm is applied - # before the projection, so we need to use `added_kv_proj_dim` as - # the number of channels for the group norm. - norm_cross_num_channels = added_kv_proj_dim - else: - norm_cross_num_channels = self.cross_attention_dim - - self.norm_cross = nn.GroupNorm( - num_channels=norm_cross_num_channels, - num_groups=cross_attention_norm_num_groups, - eps=1e-5, - affine=True, - ) - else: - raise ValueError( - f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" - ) - - linear_cls = nn.Linear - - self.linear_cls = linear_cls - self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - if processor is None: - processor = AttnProcessor2_0() - self.set_processor(processor) - - def set_use_tpu_flash_attention(self): - r""" - Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. - """ - self.use_tpu_flash_attention = True - - def set_processor(self, processor: "AttnProcessor") -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info( - f"You are removing possibly trained weights of {self.processor} with {processor}" - ) - self._modules.pop("processor") - - self.processor = processor - - def get_processor( - self, return_deprecated_lora: bool = False - ) -> "AttentionProcessor": # noqa: F821 - r""" - Get the attention processor in use. 
- - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible - # serialization format for LoRA Attention Processors. It should be deleted once the integration - # with PEFT is completed. - is_lora_activated = { - name: module.lora_layer is not None - for name, module in self.named_modules() - if hasattr(module, "lora_layer") - } - - # 1. if no layer has a LoRA activated we can return the processor as usual - if not any(is_lora_activated.values()): - return self.processor - - # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` - is_lora_activated.pop("add_k_proj", None) - is_lora_activated.pop("add_v_proj", None) - # 2. else it is not posssible that only some layers have LoRA activated - if not all(is_lora_activated.values()): - raise ValueError( - f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" - ) - - # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor - non_lora_processor_cls_name = self.processor.__class__.__name__ - lora_processor_cls = getattr( - import_module(__name__), "LoRA" + non_lora_processor_cls_name - ) + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. - hidden_size = self.inner_dim - - # now create a LoRA attention processor from the LoRA layers - if lora_processor_cls in [ - LoRAAttnProcessor, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - ]: - kwargs = { - "cross_attention_dim": self.cross_attention_dim, - "rank": self.to_q.lora_layer.rank, - "network_alpha": self.to_q.lora_layer.network_alpha, - "q_rank": self.to_q.lora_layer.rank, - "q_hidden_size": self.to_q.lora_layer.out_features, - "k_rank": self.to_k.lora_layer.rank, - "k_hidden_size": self.to_k.lora_layer.out_features, - "v_rank": self.to_v.lora_layer.rank, - "v_hidden_size": self.to_v.lora_layer.out_features, - "out_rank": self.to_out[0].lora_layer.rank, - "out_hidden_size": self.to_out[0].lora_layer.out_features, - } - - if hasattr(self.processor, "attention_op"): - kwargs["attention_op"] = self.processor.attention_op - - lora_processor = lora_processor_cls(hidden_size, **kwargs) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict( - self.to_out[0].lora_layer.state_dict() - ) - elif lora_processor_cls == LoRAAttnAddedKVProcessor: - lora_processor = lora_processor_cls( - hidden_size, - cross_attention_dim=self.add_k_proj.weight.shape[0], - rank=self.to_q.lora_layer.rank, - network_alpha=self.to_q.lora_layer.network_alpha, - ) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict( - self.to_out[0].lora_layer.state_dict() - ) - - # 
only save if used - if self.add_k_proj.lora_layer is not None: - lora_processor.add_k_proj_lora.load_state_dict( - self.add_k_proj.lora_layer.state_dict() - ) - lora_processor.add_v_proj_lora.load_state_dict( - self.add_v_proj.lora_layer.state_dict() - ) - else: - lora_processor.add_k_proj_lora = None - lora_processor.add_v_proj_lora = None - else: - raise ValueError(f"{lora_processor_cls} does not exist.") - - return lora_processor - - def forward( - self, - hidden_states: torch.FloatTensor, - freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - skip_layer_mask: Optional[torch.Tensor] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - skip_layer_mask (`torch.Tensor`, *optional*): - The skip layer mask to use. If `None`, no mask is applied. - skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): - Controls which layers to skip for spatiotemporal guidance. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - - attn_parameters = set( - inspect.signature(self.processor.__call__).parameters.keys() - ) - unused_kwargs = [ - k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters - ] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by" - f" {self.processor.__class__.__name__} and will be ignored." - ) - cross_attention_kwargs = { - k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters - } - - return self.processor( - self, - hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **cross_attention_kwargs, - ) + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` - is the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - - Returns: - `torch.Tensor`: The reshaped tensor. 
- """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is - the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is - reshaped to `[batch_size * heads, seq_len, dim // heads]`. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - - head_size = self.heads - if tensor.ndim == 3: - batch_size, seq_len, dim = tensor.shape - extra_dim = 1 - else: - batch_size, extra_dim, seq_len, dim = tensor.shape - tensor = tensor.reshape( - batch_size, seq_len * extra_dim, head_size, dim // head_size - ) - tensor = tensor.permute(0, 2, 1, 3) + Returns: + `torch.Tensor`: The reshaped tensor. + """ - if out_dim == 3: - tensor = tensor.reshape( - batch_size * head_size, seq_len * extra_dim, dim // head_size - ) + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. - return tensor + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. - def get_attention_scores( - self, - query: torch.Tensor, - key: torch.Tensor, - attention_mask: torch.Tensor = None, - ) -> torch.Tensor: - r""" - Compute the attention scores. - - Args: - query (`torch.Tensor`): The query tensor. - key (`torch.Tensor`): The key tensor. - attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. - - Returns: - `torch.Tensor`: The attention probabilities/scores. - """ - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, + Returns: + `torch.Tensor`: The attention probabilities/scores. 
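+
+    Note:
+        Conceptually this is `softmax(query @ key.transpose(-1, -2) * scale + attention_mask)`;
+        the body below fuses the optional additive mask into `torch.baddbmm` (beta=1 only when a
+        mask is given) and may upcast to float32 before the softmax.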
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, ) - del baddbmm_input + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. - if self.upcast_softmax: - attention_scores = attention_scores.float() + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. 
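+
+    Note:
+        With `nn.GroupNorm` the hidden dimension is treated as the channel axis, so the input is
+        transposed to `(batch_size, hidden_size, sequence_length)` for the norm and transposed back.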
+ """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out - attention_probs = attention_scores.softmax(dim=-1) - del attention_scores - attention_probs = attention_probs.to(dtype) +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ - return attention_probs + def __init__(self): + pass - def prepare_attention_mask( - self, - attention_mask: torch.Tensor, - target_length: int, - batch_size: int, - out_dim: int = 3, - ) -> torch.Tensor: - r""" - Prepare the attention mask for the attention computation. - - Args: - attention_mask (`torch.Tensor`): - The attention mask to prepare. - target_length (`int`): - The target length of the attention mask. This is the length of the attention mask after padding. - batch_size (`int`): - The batch size, which is used to repeat the attention mask. - out_dim (`int`, *optional*, defaults to `3`): - The output dimension of the attention mask. Can be either `3` or `4`. - - Returns: - `torch.Tensor`: The prepared attention mask. - """ - head_size = self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. 
- padding_shape = ( - attention_mask.shape[0], - attention_mask.shape[1], - target_length, - ) - padding = torch.zeros( - padding_shape, - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: - # we want to instead pad by (0, remaining_length), where remaining_length is: - # remaining_length: int = target_length - current_length - # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - def norm_encoder_hidden_states( - self, encoder_hidden_states: torch.Tensor - ) -> torch.Tensor: - r""" - Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the - `Attention` class. - - Args: - encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. - - Returns: - `torch.Tensor`: The normalized encoder hidden states. - """ + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if attention_mask is not None: # if mask is required need to tune both segmenIds fields + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones(batch_size, query.shape[2], device=query.device, dtype=torch.float32) assert ( - self.norm_cross is not None - ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - # Group norm norms along the channels dimension and expects - # input to be in the shape of (N, C, *). 
In this case, we want - # to norm along the hidden dimension, so we need to move - # (batch_size, sequence_length, hidden_size) -> - # (batch_size, hidden_size, sequence_length) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - else: - assert False - - return encoder_hidden_states - - @staticmethod - def apply_rotary_emb( - input_tensor: torch.Tensor, - freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert query.shape[2] % 128 == 0, f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert key.shape[2] % 128 == 0, f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_a = hidden_states_a.to(query.dtype) + + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states -class AttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ +class AttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ - def __init__(self): - pass + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - skip_layer_mask: Optional[torch.FloatTensor] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) + residual = hidden_states - if skip_layer_mask is not None: - skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) - - if (attention_mask is not None) and (not attn.use_tpu_flash_attention): - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view( - batch_size, attn.heads, -1, attention_mask.shape[-1] - ) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( - 1, 2 - ) - - query = attn.to_q(hidden_states) - query = attn.q_norm(query) - - if encoder_hidden_states is not None: - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states - ) - key = attn.to_k(encoder_hidden_states) - key = attn.k_norm(key) - else: # if no context provided do self-attention - encoder_hidden_states = hidden_states - key = attn.to_k(hidden_states) - key = attn.k_norm(key) - if attn.use_rope: - key = attn.apply_rotary_emb(key, freqs_cis) - query = attn.apply_rotary_emb(query, freqs_cis) - - value = attn.to_v(encoder_hidden_states) - value_for_stg = value - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the 
output of sdp = (batch, num_heads, seq_len, head_dim) - - if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' - q_segment_indexes = None - if ( - attention_mask is not None - ): # if mask is required need to tune both segmenIds fields - # attention_mask = torch.squeeze(attention_mask).to(torch.float32) - attention_mask = attention_mask.to(torch.float32) - q_segment_indexes = torch.ones( - batch_size, query.shape[2], device=query.device, dtype=torch.float32 - ) - assert ( - attention_mask.shape[1] == key.shape[2] - ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" - - assert ( - query.shape[2] % 128 == 0 - ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" - assert ( - key.shape[2] % 128 == 0 - ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" - - # run the TPU kernel implemented in jax with pallas - hidden_states_a = flash_attention( - q=query, - k=key, - v=value, - q_segment_ids=q_segment_indexes, - kv_segment_ids=attention_mask, - sm_scale=attn.scale, - ) - else: - hidden_states_a = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - - hidden_states_a = hidden_states_a.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states_a = hidden_states_a.to(query.dtype) - - if ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.AttentionSkip - ): - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( - 1.0 - skip_layer_mask - ) - elif ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.AttentionValues - ): - hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( - 1.0 - skip_layer_mask - ) - else: - hidden_states = hidden_states_a - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.Residual - ): - skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) - - if attn.residual_connection: - if ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.Residual - ): - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim -class AttnProcessor: - r""" - Default processor for performing attention-related computations. - """ + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( - 1, 2 - ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states - ) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) - query = attn.q_norm(query) - key = attn.k_norm(key) + query = attn.q_norm(query) + key = attn.k_norm(key) - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if attn.residual_connection: - hidden_states = hidden_states + residual + if attn.residual_connection: + hidden_states = hidden_states + residual - hidden_states = hidden_states / attn.rescale_output_factor + hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - linear_cls = nn.Linear - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim, bias=bias) - elif activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim, bias=bias) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, bias=bias) - else: - raise ValueError(f"Unsupported activation function: {activation_fn}") - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: - compatible_cls = (GEGLU, LoRACompatibleLinear) - for module in self.net: - if isinstance(module, compatible_cls): - hidden_states = module(hidden_states, scale) - else: - hidden_states = module(hidden_states) - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
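+
+  Example (illustrative):
+    `FeedForward(dim=1024, mult=4, activation_fn="geglu")` maps 1024 -> 4096 (GEGLU hidden) -> 1024.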
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py index 6461039fb..b4ca6b52f 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -31,115 +31,111 @@ def get_timestep_embedding( scale: float = 1, max_period: int = 10000, ): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. - """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
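+
+  In short: with half_dim = embedding_dim // 2, frequency i is
+  exp(-log(max_period) * i / (half_dim - downscale_freq_shift)), and the output concatenates
+  sin(scale * timesteps * freq_i) and cos(scale * timesteps * freq_i), optionally flipped and
+  zero-padded below when embedding_dim is odd.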
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] - # scale embeddings - emb = scale * emb + # scale embeddings + emb = scale * emb - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) - grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) - grid = grid.reshape([3, 1, w, h, f]) - pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) - pos_embed = pos_embed.transpose(1, 0, 2, 3) - return rearrange(pos_embed, "h w f c -> (f h w) c") + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 3 != 0: - raise ValueError("embed_dim must be divisible by 3") + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") - # use half of dimensions to encode grid_h - emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) - emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) - return emb + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: 
size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) - pos_shape = pos.shape + pos_shape = pos.shape - pos = pos.reshape(-1) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - out = out.reshape([*pos_shape, -1])[0] + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) - return emb + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb class SinusoidalPositionalEmbedding(nn.Module): - """Apply positional information to a sequence of embeddings. - - Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to - them - - Args: - embed_dim: (int): Dimension of the positional embedding. - max_seq_length: Maximum sequence length to apply positional embeddings - - """ - - def __init__(self, embed_dim: int, max_seq_length: int = 32): - super().__init__() - position = torch.arange(max_seq_length).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) - ) - pe = torch.zeros(1, max_seq_length, embed_dim) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) - - def forward(self, x): - _, seq_length, _ = x.shape - x = x + self.pe[:, :seq_length] - return x + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. 
+ max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py index b34df6ed3..d53b4d7ca 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -24,77 +24,75 @@ class Patchifier(ConfigMixin, ABC): - def __init__(self, patch_size: int): - super().__init__() - self._patch_size = (1, patch_size, patch_size) - @abstractmethod - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - raise NotImplementedError("Patchify method not implemented") + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) - @abstractmethod - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - pass + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") - @property - def patch_size(self): - return self._patch_size + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass - def get_latent_coords( - self, latent_num_frames, latent_height, latent_width, batch_size, device - ): - """ - Return a tensor of shape [batch_size, 3, num_patches] containing the - top-left corner latent coordinates of each latent patch. - The tensor is repeated for each batch element. - """ - latent_sample_coords = torch.meshgrid( - torch.arange(0, latent_num_frames, self._patch_size[0], device=device), - torch.arange(0, latent_height, self._patch_size[1], device=device), - torch.arange(0, latent_width, self._patch_size[2], device=device), - ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size - ) - return latent_coords + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords(self, latent_num_frames, latent_height, latent_width, batch_size, device): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
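+
+    E.g. with patch_size=1 and a 9x16x16 latent grid, the result has shape
+    [batch_size, 3, 9 * 16 * 16], where the 3 channels hold the (frame, row, column) indices.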
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange(latent_coords, "b c f h w -> b c (f h w)", b=batch_size) + return latent_coords class SymmetricPatchifier(Patchifier): - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - b, _, f, h, w = latents.shape - latent_coords = self.get_latent_coords(f, h, w, b, latents.device) - latents = rearrange( - latents, - "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", - p1=self._patch_size[0], - p2=self._patch_size[1], - p3=self._patch_size[2], - ) - return latents, latent_coords - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - output_height = output_height // self._patch_size[1] - output_width = output_width // self._patch_size[2] - latents = rearrange( - latents, - "b (f h w) (c p q) -> b c f (h p) (w q)", - h=output_height, - w=output_width, - p=self._patch_size[1], - q=self._patch_size[2], - ) - return latents + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py index 74366830c..2ade88b86 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer3d.py @@ -49,475 +49,424 @@ @dataclass class Transformer3DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. + """ + The output of [`Transformer2DModel`]. - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. 
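+
+  Note:
+    In this 3D video variant the `sample` is typically the patchified token sequence of shape
+    `(batch_size, num_tokens, out_channels)` rather than a spatial feature map.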
+ """ - sample: torch.FloatTensor + sample: torch.FloatTensor class Transformer3DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - num_vector_embeds: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' - standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - attention_type: str = "default", - caption_channels: int = None, - use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') - qk_norm: Optional[str] = None, - positional_embedding_type: str = "rope", - positional_embedding_theta: Optional[float] = None, - positional_embedding_max_pos: Optional[List[int]] = None, - timestep_scale_multiplier: Optional[float] = None, - causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated - ): - super().__init__() - self.use_tpu_flash_attention = ( - use_tpu_flash_attention # FIXME: push config down to the attention modules - ) - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.inner_dim = inner_dim - self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) - self.positional_embedding_type = positional_embedding_type - self.positional_embedding_theta = positional_embedding_theta - self.positional_embedding_max_pos = positional_embedding_max_pos - self.use_rope = self.positional_embedding_type == "rope" - self.timestep_scale_multiplier = timestep_scale_multiplier - - if self.positional_embedding_type == "absolute": - raise ValueError("Absolute positional embedding is no longer supported") - elif self.positional_embedding_type == "rope": - if positional_embedding_theta is None: - raise ValueError( - "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" - ) - if positional_embedding_max_pos is None: - raise ValueError( - "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - adaptive_norm=adaptive_norm, - standardization_norm=standardization_norm, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, - use_tpu_flash_attention=use_tpu_flash_attention, - qk_norm=qk_norm, - use_rope=self.use_rope, - ) - for d in range(num_layers) - ] - ) - - # 4. 
Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter( - torch.randn(2, inner_dim) / inner_dim**0.5 - ) - self.proj_out = nn.Linear(inner_dim, self.out_channels) - - self.adaln_single = AdaLayerNormSingle( - inner_dim, use_additional_conditions=False - ) - if adaptive_norm == "single_scale": - self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) - - self.caption_projection = None - if caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=inner_dim - ) - - self.gradient_checkpointing = False - - def set_use_tpu_flash_attention(self): - r""" - Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU - attention kernel. - """ - logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") - self.use_tpu_flash_attention = True - # push config down to the attention modules - for block in self.transformer_blocks: - block.set_use_tpu_flash_attention() - - def create_skip_layer_mask( - self, - batch_size: int, - num_conds: int, - ptb_index: int, - skip_block_list: Optional[List[int]] = None, - ): - if skip_block_list is None or len(skip_block_list) == 0: - return None - num_layers = len(self.transformer_blocks) - mask = torch.ones( - (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype - ) - for block_idx in skip_block_list: - mask[block_idx, ptb_index::num_conds] = 0 - return mask - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def get_fractional_positions(self, indices_grid): - fractional_positions = torch.stack( - [ - indices_grid[:, i] / self.positional_embedding_max_pos[i] - for i in range(3) - ], - dim=-1, - ) - return fractional_positions - - def precompute_freqs_cis(self, indices_grid, spacing="exp"): - dtype = torch.float32 # We need full precision in the freqs_cis computation. 
- dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - device = fractional_positions.device - if spacing == "exp": - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - elif spacing == "exp_2": - indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) - indices = indices.to(dtype=dtype) - elif spacing == "linear": - indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) - elif spacing == "sqrt": - indices = torch.linspace( - start**2, end**2, dim // 6, device=device, dtype=dtype - ).sqrt() - - indices = indices * math.pi / 2 - - if spacing == "exp_2": - freqs = ( - (indices * fractional_positions.unsqueeze(-1)) - .transpose(-1, -2) - .flatten(2) - ) - else: - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) - if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(self.dtype), sin_freq.to(self.dtype) - - def load_state_dict( - self, - state_dict: Dict, - *args, - **kwargs, - ): - if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): #noqa: C419 - state_dict = { - key.replace("model.diffusion_model.", ""): value - for key, value in state_dict.items() - if key.startswith("model.diffusion_model.") - } - super().load_state_dict(state_dict, *args, **kwargs) - - @classmethod - def from_pretrained( - cls, - pretrained_model_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_path = Path(pretrained_model_path) - if pretrained_model_path.is_dir(): - config_path = pretrained_model_path / "transformer" / "config.json" - with open(config_path, "r") as f: - config = make_hashable_key(json.load(f)) - - assert config in diffusers_and_ours_config_mapping, ( - "Provided diffusers checkpoint config for transformer is not suppported. " - "We only support diffusers configs found in Lightricks/LTX-Video." 
+ _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = use_tpu_flash_attention # FIXME: push config down to the attention modules + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError("If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined") + if positional_embedding_max_pos is None: + raise ValueError("If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined") + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones((num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. 
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace(start**2, end**2, dim // 6, device=device, dtype=dtype).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = (indices * fractional_positions.unsqueeze(-1)).transpose(-1, -2).flatten(2) + else: + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): # noqa: C419 + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = pretrained_model_path / "transformer" / "diffusion_pytorch_model*.safetensors" + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + with torch.device("meta"): + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, assign=True, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(".safetensors"): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + with torch.device("meta"): + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict, assign=True) + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. 
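As a side note, not part of the patch: the from_pretrained path above reads the sharded safetensors files, renames keys through TRANSFORMER_KEYS_RENAME_DICT, and only then materializes the module on the meta device, adopting the loaded tensors via load_state_dict(..., assign=True) so randomly initialized weights are never allocated. A minimal sketch of that load pattern with a hypothetical stand-in module (the assign= keyword needs a recent PyTorch):

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        # Hypothetical stand-in for the transformer built via cls.from_config(config).
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    # Pretend these tensors came from the safetensors shards gathered above.
    loaded = TinyBlock().state_dict()

    with torch.device("meta"):
        model = TinyBlock()  # parameters carry shape/dtype only; no real storage yet
    model.load_state_dict(loaded, assign=True)  # adopt the loaded CPU tensors directly
    assert model.proj.weight.device.type == "cpu"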
- config = diffusers_and_ours_config_mapping[config] - state_dict = {} - ckpt_paths = ( - pretrained_model_path - / "transformer" - / "diffusion_pytorch_model*.safetensors" - ) - dict_list = glob.glob(str(ckpt_paths)) - for dict_path in dict_list: - part_dict = {} - with safe_open(dict_path, framework="pt", device="cpu") as f: - for k in f.keys(): - part_dict[k] = f.get_tensor(k) - state_dict.update(part_dict) - - for key in list(state_dict.keys()): - new_key = key - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - state_dict[new_key] = state_dict.pop(key) - - with torch.device("meta"): - transformer = cls.from_config(config) - transformer.load_state_dict(state_dict, assign=True, strict=True) - elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( - ".safetensors" - ): - comfy_single_file_state_dict = {} - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - for k in f.keys(): - comfy_single_file_state_dict[k] = f.get_tensor(k) - configs = json.loads(metadata["config"]) - transformer_config = configs["transformer"] - with torch.device("meta"): - transformer = Transformer3DModel.from_config(transformer_config) - transformer.load_state_dict(comfy_single_file_state_dict, assign=True) - return transformer - - def forward( - self, - hidden_states: torch.Tensor, - indices_grid: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - skip_layer_mask: Optional[torch.Tensor] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. 
- encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - skip_layer_mask ( `torch.Tensor`, *optional*): - A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position - `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. - skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): - Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # for tpu attention offload 2d token masks are used. No need to transform. - if not self.use_tpu_flash_attention: - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = ( - 1 - encoder_attention_mask.to(hidden_states.dtype) - ) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Input - hidden_states = self.patchify_proj(hidden_states) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - freqs_cis = self.precompute_freqs_cis(indices_grid) - - batch_size = hidden_states.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. 
If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + (skip_layer_mask[block_idx] if skip_layer_mask is not None else None), + skip_layer_strategy, + **ckpt_kwargs, ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=(skip_layer_mask[block_idx] if skip_layer_mask is not None else None), + skip_layer_strategy=skip_layer_strategy, ) - # 2. 
Blocks - if self.caption_projection is not None: - batch_size = hidden_states.shape[0] - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view( - batch_size, -1, hidden_states.shape[-1] - ) - - for block_idx, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - freqs_cis, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels, - ( - skip_layer_mask[block_idx] - if skip_layer_mask is not None - else None - ), - skip_layer_strategy, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - freqs_cis=freqs_cis, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - skip_layer_mask=( - skip_layer_mask[block_idx] - if skip_layer_mask is not None - else None - ), - skip_layer_strategy=skip_layer_strategy, - ) - - # 3. Output - scale_shift_values = ( - self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if not return_dict: - return (hidden_states,) - - return Transformer3DModelOutput(sample=hidden_states) + # 3. Output + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py index 84b416d52..bd4d115f9 100644 --- a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -25,8 +25,6 @@ import optax import orbax.checkpoint as ocp from safetensors.torch import load_file -import requests -from urllib.parse import urljoin from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel from maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d import Transformer3DModel as Transformer3DModel @@ -34,69 +32,6 @@ from huggingface_hub import hf_hub_download import os -import importlib - - -def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import): - """ - Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. - - Args: - github_base_url (str): The base URL of the GitHub repo. 
- base_path (str): The base path where the new folder will be created. - target_folder_name (str): The name of the folder to create. - files_to_move (list): A list of file names to download and move. - module_to_import (str): The full module path to import. - """ - - target_path = os.path.join(base_path, target_folder_name) - - try: - # Create the target directory - os.makedirs(target_path, exist_ok=True) - print(f"Created directory: {target_path}") - - # Download and move files - for file_name in files_to_move: - file_url = urljoin(github_base_url, file_name) - destination_path = os.path.join(target_path, file_name) - - try: - response = requests.get(file_url, stream=True) - response.raise_for_status() - - with open(destination_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - - print(f"Downloaded and moved: {file_name} -> {destination_path}") - - except requests.exceptions.RequestException as e: - print(f"Error downloading {file_name}: {e}") - except OSError as e: - print(f"Error writing file {file_name}: {e}") - print("Files downloaded and moved successfully.") - - # Verify that the folder exists - if not os.path.exists(target_path): - print(f"Error: Target folder {target_path} does not exist after files download.") - # Dynamically import the module - try: - imported_module = importlib.import_module(module_to_import) - print(f"Module '{module_to_import}' imported successfully.") - # Access the class - transformer_class = getattr(imported_module, "Transformer3DModel") - print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") - return transformer_class - except ImportError as e: - print(f"Error importing module '{module_to_import}': {e}") - except AttributeError as e: - print(f"Error accessing class 'Transformer3DModel': {e}") - - except OSError as e: - print(f"Error during file system operation: {e}") - except Exception as e: - print(f"An unexpected error occurred: {e}") class Checkpointer: diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index e5f417bed..0ca816f9e 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -28,6 +28,7 @@ AutoModelForCausalLM, AutoProcessor, ) +from maxdiffusion import max_logging from huggingface_hub import hf_hub_download from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, @@ -59,10 +60,10 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids): # Note: reference shape annotated for first pass default inference parameters - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) # (3, 256, 4096) float32 - print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) # (3, 3, 3072) float32 - print("latents.shape: ", latents.shape, latents.dtype) # (1, 3072, 128) float 32 - print( + max_logging.log("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) # (3, 256, 4096) float32 + max_logging.log("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) # (3, 3, 3072) float32 + max_logging.log("latents.shape: ", latents.shape, latents.dtype) # (1, 3072, 128) float 32 + max_logging.log( "encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype ) # (3, 256) int32 @@ -156,7 +157,7 @@ def 
load_transformer(cls, config): mesh=mesh, weights_init_fn=weights_init_fn, checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", + checkpoint_item="ltxvid_transformer", model_params=None, training=False, ) @@ -993,7 +994,7 @@ def __call__( skip_block_list=config.first_pass["skip_block_list"], ) latents = result - print("first pass done") + max_logging.log("first pass done") latent_upsampler = self.load_latent_upsampler(config) upsampled_latents = self._upsample_latents(latent_upsampler, latents) upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py index 6c39730df..f3f2dcba6 100644 --- a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -167,6 +167,9 @@ class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + ''' + Note: shifting and stochastic sampling not tested + ''' dtype: jnp.dtype order = 1 From 8df8dbd6d00a95fc1f65fec8c60ad648a7ea79e4 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 25 Jul 2025 21:23:56 +0000 Subject: [PATCH 69/69] style fix --- src/maxdiffusion/schedulers/scheduling_rectified_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py index f3f2dcba6..c4b2657e3 100644 --- a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -167,9 +167,9 @@ class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): - ''' + """ Note: shifting and stochastic sampling not tested - ''' + """ dtype: jnp.dtype order = 1
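A final orientation note, separate from the patch: the step logic of FlaxRectifiedFlowMultistepScheduler is not shown in these hunks. The first-order update that rectified-flow samplers reduce to is a plain Euler step along the straight-line path between data and noise; a generic sketch with hypothetical names, not this class's API:

    def rectified_flow_euler_step(sample, velocity, sigma, sigma_next):
        # x_{sigma_next} = x_sigma + (sigma_next - sigma) * v_theta(x_sigma, sigma)
        return sample + (sigma_next - sigma) * velocity

Multistep variants typically reuse earlier velocity predictions for higher-order updates, and the "shifting" mentioned in the docstring rescales the sigma schedule before stepping.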