From 377619088bcd5b04fcab5413b3dfd740c32eb27a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 19:05:46 +0000 Subject: [PATCH 01/34] set up files for ltxvid --- src/maxdiffusion/__init__.py | 721 +++++++++++++------------ src/maxdiffusion/configs/ltx_video.yml | 50 ++ src/maxdiffusion/generate_ltx_video.py | 73 +++ src/maxdiffusion/models/__init__.py | 5 +- 4 files changed, 490 insertions(+), 359 deletions(-) create mode 100644 src/maxdiffusion/configs/ltx_video.yml create mode 100644 src/maxdiffusion/generate_ltx_video.py diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..677d64e4e 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,438 +65,447 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [ + name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [ + name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - 
"KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - 
"StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + 
"WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend( + ["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = [ + "FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = [ + "FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = [ + "Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - 
"FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 
+ else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, + ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, - ) + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..ac333d329 --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,50 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..81c832c3c --- /dev/null +++ 
b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,73 @@ +from absl import app +from typing import Sequence +import jax +import json +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import functools +import jax.numpy as jnp +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, +) +from jax.sharding import Mesh, PartitionSpec as P + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + + +def run(config): + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) + + # load in model config + config_path = os.path.join( + base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + transformer = Transformer3DModel( + **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights( + key, batch_size, text_tokens, num_tokens, features, eval_only=False) + + key, split_key = jax.random.split(key) + weights_init_fn = functools.partial( + transformer.init_weights, + split_key, + batch_size, + text_tokens, + num_tokens, + features, + eval_only=False + ) + + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys From 13656fb457723f52e935e8afab69a983d5cfd68a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 20:32:05 +0000 Subject: [PATCH 02/34] ltx-video-transformer-setup --- src/maxdiffusion/configs/ltx_video.yml | 15 + src/maxdiffusion/generate_ltx_video.py | 29 +- src/maxdiffusion/models/__init__.py | 17 +- src/maxdiffusion/models/ltx_video/__init__.py | 0 .../models/ltx_video/gradient_checkpoint.py | 70 ++ src/maxdiffusion/models/ltx_video/linear.py | 111 ++ .../models/ltx_video/repeatable_layer.py | 105 ++ .../models/ltx_video/transformers/__init__.py | 0 .../ltx_video/transformers/activations.py | 176 ++++ .../models/ltx_video/transformers/adaln.py | 201 ++++ .../ltx_video/transformers/attention.py | 945 ++++++++++++++++++ .../transformers/caption_projection.py | 40 + .../ltx_video/transformers/transformer3d.py | 322 ++++++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 24 + 14 files changed, 2036 insertions(+), 19 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/gradient_checkpoint.py create mode 100644 src/maxdiffusion/models/ltx_video/linear.py create mode 100644 src/maxdiffusion/models/ltx_video/repeatable_layer.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/activations.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/adaln.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/caption_projection.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/transformer3d.py create mode 100644 src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index ac333d329..954922521 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,3 +1,18 @@ +# Copyright 2025 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + #hardware hardware: 'tpu' skip_jax_distributed_system: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 81c832c3c..6d96aa8c2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,20 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + + from absl import app from typing import Sequence import jax @@ -50,17 +67,7 @@ def run(config): text_tokens, num_tokens, features, - eval_only=False - ) - - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, + eval_only=True ) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 96a6f1286..20c27ab20 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,14 +25,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..f32cc9459 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,70 @@ +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = 
auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..fd92c695d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,111 @@ +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] 
= () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features):], + kernel_shape[-len(features):], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..882f21ace --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,105 @@ +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles + + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction 
+ """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = { + nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( + self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn( + self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..3e1fd6d6e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,176 @@ +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from diffusers.utils.deprecation_utils import deprecate + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError( + f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..374af6acc --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,201 @@ +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint( + emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate( + [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint( + sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + + """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb( + timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..4ade671c7 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,945 @@ +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto + +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", 
"rms_norm"] + assert self.adaptive_norm in [ + "single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name( + hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", + "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * \ + (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm( + hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
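# Toy illustration of the "single_scale_shift" adaptive norm applied above: a learned per-layer
# table is added to the projected timestep embedding, split into six chunks, and the
# (shift, scale) pair modulates the normalized activations while the gate scales the branch
# output before the residual add. Shapes below are assumptions chosen for readability.
import jax
import jax.numpy as jnp

batch, tokens, dim = 2, 16, 32
key = jax.random.PRNGKey(0)
scale_shift_table = jax.random.normal(key, (6, dim)) / dim**0.5  # per-layer parameter
timestep_emb = jnp.zeros((batch, 1, 6 * dim))                    # from AdaLayerNormSingle

ada = scale_shift_table[None, None] + timestep_emb.reshape(batch, 1, 6, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    jnp.squeeze(a, axis=2) for a in jnp.split(ada, 6, axis=2)
)

x = jax.random.normal(key, (batch, tokens, dim))
norm_x = x * (1 + scale_msa) + shift_msa  # pre-attention modulation
attn_out = norm_x                         # stand-in for the attention branch
x = gate_msa * attn_out + x               # gated residual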
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any( + device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print( + "Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get( + "scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", + "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint( + hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint( + encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape( + hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states( + encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key 
= self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones( + (batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", + "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes( + hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + \ + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1]( + hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes( + hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape( + skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad( + attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + 
attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), + kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError( + f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
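# Minimal usage sketch of the Pallas TPU flash-attention wrapper that AttentionOp calls above
# (single host, no shard_map). It requires a TPU backend to run; the toy shapes and the all-ones
# segment ids are assumptions for illustration only. Segment ids mark which tokens may attend to
# each other: matching ids attend, differing ids are masked out, and they are cast to float32
# just as the wrapper above does.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention, SegmentIds

batch, heads, q_len, kv_len, head_dim = 1, 8, 256, 256, 64
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (batch, heads, q_len, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(key, (batch, heads, kv_len, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(key, (batch, heads, kv_len, head_dim), dtype=jnp.bfloat16)

segs = SegmentIds(
    jnp.ones((batch, q_len), dtype=jnp.float32),
    jnp.ones((batch, kv_len), dtype=jnp.float32),
)
out = flash_attention(q, k, v, None, segs, sm_scale=1.0 / head_dim**0.5)
print(out.shape)  # (1, 8, 256, 64)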
+ + TPU kernel ops run in grids: the larger the grid, the more data is loaded into SRAM, + and we want to utilize that SRAM as fully as possible. + + Grids that are too large cause cache misses and slow the computation down while the faster SRAM + fetches the next block of data from the slower HBM. + + A balance has to be struck to get the best performance; that balance should be computed from the + information supplied by q and k (which give the query and key/value sequence lengths) + together with the SRAM cache size. + + ** SRAM cache size for TPU + V5P - 1MB SRAM per core + + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used + + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) + + +class ExplicitAttention(nn.Module): + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision.
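# The RMSNorm defined here reduces to y = x / sqrt(mean(x**2) + eps) * scale, with the statistic
# taken in float32. A minimal reference implementation (toy input, assumed sizes) for comparison:
import jax
import jax.numpy as jnp

def rms_norm_reference(x, scale, eps=1e-5):
    var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    return (x * jax.lax.rsqrt(var + eps)) * scale

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
print(rms_norm_reference(x, scale=jnp.ones(8)).std(axis=-1))  # roughly 1.0 per row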
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning( + self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype( + jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * \ + jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) + * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError( + f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if 
self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..dff8b8c62 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,40 @@ +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..4368c35fb --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,322 @@ +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 
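# Toy check of the rotary embedding helper above: with freqs of zero (cos=1, sin=0) the input
# passes through unchanged, and with a 90-degree rotation each (even, odd) pair (x0, x1) maps to
# (-x1, x0). The tiny tensor below is an illustrative assumption, not part of the patch.
import jax.numpy as jnp

def rotate_pairs(x):
    x2 = x.reshape(*x.shape[:-1], -1, 2)
    t1, t2 = jnp.split(x2, 2, axis=-1)
    return jnp.concatenate([-t2, t1], axis=-1).reshape(*x.shape)

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
cos, sin = jnp.zeros_like(x), jnp.ones_like(x)  # 90-degree rotation everywhere
print(x * cos + rotate_pairs(x) * sin)          # [[-2.  1. -4.  3.]]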
0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning( + scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm( + epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + 
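# The RemattedBasicTransformerBlock above is produced by GradientCheckpointType / RepeatableLayer,
# whose internals are not part of this diff. As a rough analogue only, gradient checkpointing in
# Flax is typically applied by wrapping the module class with nn.remat before instantiation; the
# stand-in module and policy below are assumptions for illustration.
import jax
import flax.linen as nn

CheckpointedBlock = nn.remat(
    nn.Dense,  # stand-in for BasicTransformerBlock
    policy=jax.checkpoint_policies.nothing_saveable,  # recompute activations in the backward pass
)
block = CheckpointedBlock(features=128)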
norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal( + split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal( + split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal( + split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal( + split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection( + encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, + :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + 
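# init_weights above uses jax.eval_shape to trace self.init abstractly, returning parameter
# ShapeDtypeStructs without allocating device memory. A minimal standalone example of the same
# idea (toy module and sizes are assumptions):
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=16)
abstract_params = jax.eval_shape(
    model.init, jax.random.PRNGKey(0), jnp.ones((1, 8), dtype=jnp.bfloat16)
)
print(jax.tree_util.tree_map(lambda x: x.shape, abstract_params))
# e.g. {'params': {'bias': (16,), 'kernel': (8, 16)}}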
scale_shift_values, ("activation_batch", "activation_length", + "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. + dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, + axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..02f13b15a --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,24 @@ +{ + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000 +} \ No newline at end of file From 7bed4f99d1b8c19eb1ee622ee895f1ca31b1b870 Mon Sep 17 00:00:00 2001 
From: "serenagu@google.com" Date: Thu, 26 Jun 2025 21:12:17 +0000 Subject: [PATCH 03/34] formatting --- src/maxdiffusion/__init__.py | 723 ++++--- src/maxdiffusion/generate_ltx_video.py | 60 +- src/maxdiffusion/models/__init__.py | 17 +- .../models/ltx_video/gradient_checkpoint.py | 102 +- src/maxdiffusion/models/ltx_video/linear.py | 170 +- .../models/ltx_video/repeatable_layer.py | 129 +- .../ltx_video/transformers/activations.py | 275 ++- .../models/ltx_video/transformers/adaln.py | 328 ++-- .../ltx_video/transformers/attention.py | 1696 ++++++++--------- .../transformers/caption_projection.py | 60 +- .../ltx_video/transformers/transformer3d.py | 585 +++--- 11 files changed, 2023 insertions(+), 2122 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 677d64e4e..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,447 +65,440 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [ - name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [ - name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - 
"DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import 
dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - 
"StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + 
"VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - ["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not 
(is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [ - name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = [ - "FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = [ - "FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["models.ltx_video.transformers.transformer3d"] = [ - "Transformer3DModel"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.ltx_video.transformers.transformer3d import Transformer3DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin -else: - import sys - - 
sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 + else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6d96aa8c2..d05203f5c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -14,67 +14,55 @@ limitations under the License. 
""" - from absl import app from typing import Sequence import jax import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os -import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, - setup_initial_state, ) -from jax.sharding import Mesh, PartitionSpec as P def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + # load in model config + config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 - base_dir = os.path.dirname(__file__) + transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) - # load in model config - config_path = os.path.join( - base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) + key, split_key = jax.random.split(key) - transformer = Transformer3DModel( - **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights( - key, batch_size, text_tokens, num_tokens, features, eval_only=False) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( - transformer.init_weights, - split_key, - batch_size, - text_tokens, - num_tokens, - features, - eval_only=True - ) + weights_init_fn = functools.partial( + transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + ) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 20c27ab20..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,15 +25,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import 
Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index f32cc9459..ef8c530ba 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -8,63 +8,63 @@ class GradientCheckpointType(Enum): - """ - Defines the type of the gradient checkpoint we will have + """ + Defines the type of the gradient checkpoint we will have - NONE - means no gradient checkpoint - FULL - means full gradient checkpoint, wherever possible (minimum memory usage) - MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, - except for ones that involve batch dimension - that means that all attention and projection - layers will have gradient checkpoint, but not the backward with respect to the parameters - """ + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ - NONE = auto() - FULL = auto() - MATMUL_WITHOUT_BATCH = auto() + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() - @classmethod - def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": - """ - Constructs the gradient checkpoint type from a string + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string - Args: - s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. 
- Returns: - GradientCheckpointType: The policy that corresponds to the string - """ - if s is None: - s = "none" - return GradientCheckpointType[s.upper()] + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] - def to_jax_policy(self): - """ - Converts the gradient checkpoint type to a jax policy - """ - match self: - case GradientCheckpointType.NONE: - return SKIP_GRADIENT_CHECKPOINT_KEY - case GradientCheckpointType.FULL: - return None - case GradientCheckpointType.MATMUL_WITHOUT_BATCH: - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nn.Module) -> nn.Module: - """ - Applies a gradient checkpoint policy to a module - if no policy is needed, it will return the module as is + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is - Args: - module (nn.Module): the module to apply the policy to + Args: + module (nn.Module): the module to apply the policy to - Returns: - nn.Module: the module with the policy applied - """ - policy = self.to_jax_policy() - if policy == SKIP_GRADIENT_CHECKPOINT_KEY: - return module - return nn.remat( # pylint: disable=invalid-name - module, - prevent_cse=False, - policy=policy, - ) + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index fd92c695d..31b21cdd9 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -13,99 +13,97 @@ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) -NdInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] -KernelInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] class DenseGeneral(nn.Module): - """A linear transformation with flexible axes. - - Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 - - Attributes: - features: tuple with numbers of output features. 
- axis: tuple with axes to apply the transformation on. - weight_dtype: the dtype of the weights (default: float32). - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - use_bias: whether to add bias in linear transformation. - bias_norm: whether to add normalization before adding bias. - quant: quantization config, defaults to None implying no quantization. + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. """ - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - weight_dtype: jnp.dtype = jnp.float32 - dtype: np.dtype = jnp.float32 - kernel_init: KernelInitializer = lecun_normal() - kernel_axes: Tuple[Optional[str], ...] = () - use_bias: bool = False - matmul_precision: str = "default" - - bias_init: Initializer = jax.nn.initializers.constant(0.0) - - @nn.compact - def __call__(self, inputs: jax.Array) -> jax.Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. 
- """ - - def compute_dot_general(inputs, kernel, axis, contract_ind): - """Computes a dot_general operation that may be quantized.""" - dot_general = jax.lax.dot_general - matmul_precision = jax.lax.Precision(self.matmul_precision) - return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) - - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - kernel_in_axis = np.arange(len(axis)) - kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) - kernel = self.param( - "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - ) - kernel = jnp.asarray(kernel, self.dtype) - - contract_ind = tuple(range(0, len(axis))) - output = compute_dot_general(inputs, kernel, axis, contract_ind) - - if self.use_bias: - bias_axes, bias_shape = ( - self.kernel_axes[-len(features):], - kernel_shape[-len(features):], - ) - bias = self.param( - "bias", - nn.with_logical_partitioning(self.bias_init, bias_axes), - bias_shape, - self.weight_dtype, - ) - bias = jnp.asarray(bias, self.dtype) - - output += bias - return output + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + # kernel_in_axis = np.arange(len(axis)) + # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 882f21ace..aaed41048 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -7,99 +7,96 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - @nn.compact - def __call__(self, 
*args) -> Tuple[jax.Array, None]: - """ - jax carry-op format of block - assumes the input contains an input tensor to the block along with kwargs that might be send to the block - kwargs are assumed to have static role, while the input changes between cycles + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles - Returns: - Tuple[jax.Array, None]: Output tensor from the block - """ - mod = self.module(*self.module_init_args, **self.module_init_kwargs) - output = mod(*args) - return output, None + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = { - nn.PARTITION_NAME: self.pspec_name} - - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( - self.param_scan_axis) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, # Separate params per timestep - split_rngs={"params": True}, - in_axes=(nn.broadcast,) * (len(args) - 1), - length=self.num_layers, - **scan_kwargs, - ) - wrapped_function = scan_fn( - self.module, self.module_init_args, self.module_init_kwargs) - x, _ = wrapped_function(*args) - return x + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = 
self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 3e1fd6d6e..4a78b48ea 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -22,155 +22,154 @@ @jax.jit def approximate_gelu(x: jax.Array) -> jax.Array: - """ - Computes Gaussian Error Linear Unit (GELU) activation function + """ + Computes Gaussian Error Linear Unit (GELU) activation function - Args: - x (jax.Array): The input tensor + Args: + x (jax.Array): The input tensor - jax.Array: The output tensor - """ - # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs - # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior - if x.dtype in (jax.numpy.float64,): - x = x.clip(-10, None) - return jax.nn.gelu(x, approximate=True) + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) def get_activation(act_fn: str): - """Returns the activation function from string.""" - act_fn = act_fn.lower() - if act_fn in ACTIVATION_FUNCTIONS: - return ACTIVATION_FUNCTIONS[act_fn] - raise ValueError(f"Unsupported activation function: {act_fn}") + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - approximate: str = "none" - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] 
= () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def gelu(self, gate: jax.Array) -> jax.Array: - approximate_to_tanh = self.approximate == "tanh" - if approximate_to_tanh: - return approximate_gelu(gate) - else: - return jax.nn.gelu(gate, approximate=False) - - @nn.compact - def __call__(self, hidden_states): - if self.approximate not in ("none", "tanh"): - raise ValueError( - f"approximate must be 'none' or 'tanh', got {self.approximate}") - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - hidden_states = proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states class GEGLU(nn.Module): - r""" - A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
- deprecate("scale", "1.0.0", deprecation_message) - - proj = DenseGeneral( - features=self.dim_out * 2, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - - hidden_states = proj(hidden_states) - hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) - return hidden_states * jax.nn.gelu(gate, approximate=False) + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) class ApproximateGELU(nn.Module): - r""" - The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this - [paper](https://arxiv.org/abs/1606.08415). - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, x): - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - x = proj(x) - return x * jax.nn.sigmoid(1.702 * x) + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 374af6acc..4bc27e8bc 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -17,185 +17,177 @@ def get_timestep_embedding_multidim( scale: float = 1, max_period: int = 10000, ) -> jnp.ndarray: - """ - Computes sinusoidal timestep embeddings while preserving the original dimensions. - No reshaping to 1D is performed at any stage. - - Args: - timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. - embedding_dim (int): The dimension of the output. - flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) - or `sin, cos` (if False). - downscale_freq_shift (float): Controls the delta between frequencies between dimensions. - scale (float): Scaling factor applied to the embeddings. - max_period (int): Controls the maximum frequency of the embeddings. - - Returns: - jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. - """ - half_dim = embedding_dim // 2 - exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) - exponent = exponent / (half_dim - downscale_freq_shift) - shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) - emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape - emb = nn.with_logical_constraint( - emb, ("activation_batch", "activation_norm_length", "activation_embed")) - # Broadcasting to match shape (*timesteps.shape, half_dim) - emb = timesteps[..., None] * emb - emb = scale * emb - # Shape (*timesteps.shape, embedding_dim) - emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) - if flip_sin_to_cos: - emb = jnp.concatenate( - [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) - - return emb + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb class TimestepEmbedding(nn.Module): - in_channels: int - time_embed_dim: int - act_fn: str = "silu" - out_dim: Optional[int] = None - sample_proj_bias: bool = True - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers efficiently""" - self.linear_1 = DenseGeneral( - self.time_embed_dim, - use_bias=self.sample_proj_bias, - kernel_axes=(None, "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - ) - - self.act = get_activation(self.act_fn) - time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim - self.linear_2 = DenseGeneral( - time_embed_dim_out, - use_bias=self.sample_proj_bias, - kernel_axes=("embed", "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - ) - - def __call__(self, sample, condition=None): - sample = nn.with_logical_constraint( - sample, ("activation_batch", "activation_norm_length", "activation_embed")) - sample = self.linear_1(sample) - sample = self.act(sample) - sample = self.linear_2(sample) - return sample + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample class Timesteps(nn.Module): - num_channels: int - flip_sin_to_cos: bool - downscale_freq_shift: float - scale: int = 1 - - def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: - t_emb = get_timestep_embedding_multidim( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - return t_emb + 
num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - """ - - """ - - embedding_dim: int - size_emb_dim: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize sub-modules.""" - self.outdim = self.size_emb_dim - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, - time_embed_dim=self.embedding_dim, - name="timestep_embedder", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.astype(hidden_dtype)) - return timesteps_emb + """ """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb class AdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ + Compute AdaLayerNorm-Single modulation. 
- embedding_dim: int - embedding_coefficient: int = 6 - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - self.emb = AlphaCombinedTimestepSizeEmbeddings( - self.embedding_dim, - size_emb_dim=self.embedding_dim // 3, - name="emb", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - self.silu = jax.nn.silu - self.linear = DenseGeneral( - self.embedding_coefficient * self.embedding_dim, - use_bias=True, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear", - ) - - def __call__( - self, - timestep: jnp.ndarray, - added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, - batch_size: Optional[int] = None, - hidden_dtype: Optional[jnp.dtype] = None, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Compute AdaLayerNorm-Single modulation. - - Returns: - Tuple: - - Processed embedding after SiLU + linear transformation. - - Original embedded timestep. - """ - embedded_timestep = self.emb( - timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) - return self.linear(self.silu(embedded_timestep)), embedded_timestep + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4ade671c7..5d12e7813 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -25,921 +25,869 @@ class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - def __call__(self, x): - return x + + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in [ - "single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." 
- - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. 
Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. 
Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), dtype=self.dtype, weight_dtype=self.weight_dtype, + name="to_out.0", matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print( - "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name( - hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] - - # 0. 
Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", - "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) - ) - norm_hidden_states = norm_hidden_states * \ - (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. Self-Attention - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, 
channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if 
self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - attn_output = nn.with_logical_constraint( - attn_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm( - hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint( - attn_input, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. 
Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) +class AttentionOp(nn.Module): - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * \ - (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint( - ff_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
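+    The heuristic below caps each block dimension at 1024 for bfloat16 inputs (512 for other
+    dtypes, and always 512 for `block_k_dq`), clamped to the query/key sequence lengths, with a
+    batch block of 1.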
+ TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
- ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) - - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="to_out.0", - matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any( - device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print( - "Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> jnp.ndarray: - cross_attention_kwargs = { - k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - assert cross_attention_kwargs.get( - "scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", - "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint( - hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint( - encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape( - hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states( - encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key 
= self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", - "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") + ** SRAM cache size for TPU + V5P - 1MB SRAM per core - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) - if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones( - (batch_size, query.shape[2]), dtype=jnp.float32) - hidden_states_a = self.attention( - query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype - ) +class ExplicitAttention(nn.Module): - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", - "activation_length", "activation_kv") - ) + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." 
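+    # Reference (non-flash) attention path: build an additive bias from the segment ids so a
+    # query position may only attend to key positions with a matching segment id; mismatched
+    # pairs receive -inf before the softmax below.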
+ attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v - hidden_states_a = jnp.reshape(jnp.swapaxes( - hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + \ - hidden_states * (1.0 - skip_layer_mask) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1]( - hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes( - hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape( - skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad( - attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. 
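+
+  Concretely (matching the implementation below): y = x * rsqrt(mean(x**2, axis=-1) + epsilon),
+  multiplied by a learned per-feature `scale` when `elementwise_affine` is True.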
+ """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. -class AttentionOp(nn.Module): - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), - kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError( - f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. - qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - None, - None, - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. 
- - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can - - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM - - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size - - ** SRAM cache size for TPU - V5P - 1MB SRAM per core - - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used - - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - ) + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + Returns: + jax.Array: Normed data + """ -class ExplicitAttention(nn.Module): - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." - attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. 
- """ + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. - - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. - - NOTE: if weight is in mixed precision, the operand should be in the same precision. - Args: - hidden_states (jax.Array): Input data - - Returns: - jax.Array: Normed data - """ - - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning( - self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype( - jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * \ - jax.lax.rsqrt(variance + self.epsilon) - - if self.elementwise_affine: - # convert into half-precision if necessary - hidden_states = (hidden_states.astype(self.dtype) - * scale.astype(self.dtype)).astype(input_dtype) - else: - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
- """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError( - f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
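+    inner_dim (`int`, *optional*): Hidden dimension of the feed-forward layer. If not given, it
+      defaults to `dim * mult` (which must be at least 256) rounded to the nearest multiple of 256.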
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
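+
+  The last dimension is treated as interleaved (x1, x2) pairs, and the output is computed as
+  input * cos + rotate_half(input) * sin, where rotate_half maps each pair (x1, x2) to (-x2, x1).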
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index dff8b8c62..f2b1af101 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -6,35 +6,35 @@ class CaptionProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. - """ + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
+ """ - in_features: int - hidden_size: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" - @nn.compact - def __call__(self, caption): - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - )(caption) - hidden_states = approximate_gelu(hidden_states) - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - )(hidden_states) - return hidden_states + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 4368c35fb..dac8e6280 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -13,310 +13,297 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning( - scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm( - epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal( - 
split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal( - split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal( - split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal( - split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) - if eval_only: - return jax.eval_shape( - self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection( - encoder_hidden_states) - - if self.num_layers > 0: - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - ) - # Output processing - - scale_shift_values = ( - self.scale_shift_table[jnp.newaxis, jnp.newaxis, - :, :] + embedded_timestep[:, :, jnp.newaxis] - ) - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", - "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + 
timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # 
bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. + """ + Computes log of x with defined base. 
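# A quick numeric check of log_base and of how the RoPE schedule below uses it:
# log_base(1, theta) == 0 and log_base(theta, theta) == 1, so the linspace over those
# endpoints spans [0, 1] and theta ** linspace yields dim // 6 geometrically spaced
# frequencies between 1 and theta. theta = 10000.0 matches positional_embedding_theta in
# the 13B json; 4 points are used here only to keep the printout short.
import jax.numpy as jnp

def log_base(x, base):
  return jnp.log(x) / jnp.log(base)

theta = 10000.0
exponents = jnp.linspace(log_base(1.0, theta), log_base(theta, theta), 4)
print(theta ** exponents)  # ~[1.0, 21.5, 464.2, 10000.0]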
- Args: - x (jax.Array): log value - base (jax.Array): base of the log + Args: + x (jax.Array): log value + base (jax.Array): base of the log - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. - """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] - for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, - axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. 
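# What FreqsCisPrecomputer.get_fractional_positions computes, shown on a toy grid:
# indices_grid carries the absolute (frame, row, col) index of every token, and dividing by
# positional_embedding_max_pos ([20, 2048, 2048] in the 13B json) maps each coordinate into
# [0, 1] before the rotary angles are built from it. The grid values here are made up for
# illustration.
import jax.numpy as jnp

indices_grid = jnp.array([[[0.0, 5.0, 10.0],      # frame index per token
                           [0.0, 64.0, 128.0],    # row index per token
                           [0.0, 64.0, 128.0]]])  # col index per token -> (batch=1, 3, tokens=3)
max_pos = [20, 2048, 2048]
fractional = jnp.stack([indices_grid[:, i] / max_pos[i] for i in range(3)], axis=-1)
print(fractional.shape)  # (1, 3, 3): (batch, tokens, coordinate)
print(fractional[0, 2])  # [0.5, 0.0625, 0.0625]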
+ dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) From 7e098c586fad8874fa0d62912a42a11f159c9545 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 26 Jun 2025 22:56:13 +0000 Subject: [PATCH 04/34] format fixed --- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 10 +-- .../ltx_video/transformers/attention.py | 2 +- .../ltx_video/transformers/transformer3d.py | 68 +++++++------------ .../ltx_video/xora_v1.2-13B-balanced-128.json | 3 +- 5 files changed, 33 insertions(+), 52 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 954922521..eb44d253b 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -62,4 +62,4 @@ cache_latents_text_encoder_outputs: True per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 -jit_initializers: True \ No newline at end of file +jit_initializers: True diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index d05203f5c..6efe564b2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,11 +20,13 @@ import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os +import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, ) +from jax.sharding import Mesh def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): @@ -38,7 +40,7 @@ def run(config): key = jax.random.PRNGKey(0) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) @@ -49,12 +51,10 @@ def run(config): model_config = json.load(f) transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 key, split_key = jax.random.split(key) - - - weights_init_fn = functools.partial( + weights_init_fn = functools.partial( # noqa F841 transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True ) diff --git 
a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 5d12e7813..4812b89ba 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -438,7 +438,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index dac8e6280..cf599f26c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module): only_cross_attention: bool = False double_self_attention: bool = False upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' norm_elementwise_affine: bool = True norm_eps: float = 1e-5 attention_type: str = "default" caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') qk_norm: Optional[str] = None positional_embedding_type: str = "rope" positional_embedding_theta: Optional[float] = None @@ -98,7 +96,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( + module_init_kwargs=dict( # noqa C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, @@ -139,46 +137,30 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + 
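# Why the eval_only branch of init_weights goes through jax.eval_shape: it traces self.init
# abstractly and returns a pytree of jax.ShapeDtypeStruct leaves, so the full parameter tree
# (and any shardings derived from it) can be built without allocating device memory for the
# 13B weights. A toy stand-in module, not the real Transformer3DModel:
import jax
import jax.numpy as jnp
import flax.linen as nn

toy = nn.Dense(features=4)
abstract_params = jax.eval_shape(toy.init, jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, abstract_params))
# -> params with kernel (8, 4) and bias (4,), but no arrays materialized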
example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + if eval_only: return jax.eval_shape( self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, + key, + **example_inputs, )["params"] else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] + return self.init(key, **example_inputs)["params"] def __call__( self, @@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: @nn.compact def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 + dtype = jnp.float32 # We need full precision in the freqs_cis computation. dim = self.inner_dim theta = self.positional_embedding_theta @@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: indices = indices * jnp.pi / 2 freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 cos_freq = jnp.cos(freqs).repeat(2, axis=-1) sin_freq = jnp.sin(freqs).repeat(2, axis=-1) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 02f13b15a..75b16b011 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -20,5 +20,6 @@ "positional_embedding_type": "rope", "positional_embedding_theta": 10000.0, "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000 + "timestep_scale_multiplier": 1000, + "in_channels": 128 } \ No newline at end of file From e18128c3f19db8a8cba15195397281addaf0558a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 18:17:48 +0000 Subject: [PATCH 05/34] transformer step and test --- src/maxdiffusion/configs/ltx_video.yml | 24 ++- src/maxdiffusion/generate_ltx_video.py | 171 ++++++++++++--- src/maxdiffusion/max_utils.py | 21 +- .../ltx_video/xora_v1.2-13B-balanced-128.json | 1 + src/maxdiffusion/pyconfig.py | 17 ++ .../tests/ltx_transformer_step_test.py | 198 ++++++++++++++++++ .../tests/ltx_vid_transformer_test_ref_pred | Bin 0 -> 263834 bytes 7 files changed, 402 insertions(+), 30 deletions(-) create mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py create mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index eb44d253b..8f1ee8a7d 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,12 +22,25 @@ 
weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', ['data','fsdp']], @@ -40,13 +53,19 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 + ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 @@ -63,3 +82,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6efe564b2..bad791cee 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,23 +1,8 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
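# A minimal sketch of what the parallelism settings above turn into at runtime: each entry
# in mesh_axes becomes one mesh dimension, the ici_* values size those dimensions (-1 meaning
# "fill in whatever devices are left"), and data_sharding names the axes a batch is split
# over. Shown here with the original three-axis layout and mesh_utils directly, rather than
# maxdiffusion's create_device_mesh helper.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))  # (data, fsdp, tensor)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))
batch_sharding = NamedSharding(mesh, PartitionSpec(("data", "fsdp", "tensor")))
print(mesh.shape)  # e.g. {'data': 1, 'fsdp': 8, 'tensor': 1} on eight devices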
-""" - from absl import app from typing import Sequence import jax import json +from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os import functools @@ -25,39 +10,171 @@ from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, + setup_initial_state, + get_memory_allocations, ) -from jax.sharding import Mesh +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) print("latents.shape: ", latents.shape, latents.dtype) print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return noise_pred def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 + mesh = Mesh(devices_array, config.mesh_axes) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) - # load in model config + ##load in model config config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa 
F841 + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( # noqa F841 - transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + + # create dummy inputs: + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids + ) + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, ) + noise_pred = p_run_inference(states).block_until_ready() + print(noise_pred) # (4, 256, 128) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..f3f5148b2 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if config.ici_fsdp_transpose_parallelism: + num_slices = 1 + # if 
config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -402,7 +417,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 75b16b011..c5b3c0ef9 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,4 +1,5 @@ { + "ckpt_path": "", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 67437ba0b..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -154,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py new file mode 100644 index 000000000..b0a266b70 --- /dev/null +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -0,0 +1,198 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
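# The loop_body/run_inference pair above (and its copy in the new test below) threads
# (latents, state, timestep) through jax.lax.fori_loop, replacing latents with the model
# prediction on each iteration. The same control flow on a toy carry, with a stand-in
# lambda instead of transformer.apply:
import functools
import jax
import jax.numpy as jnp

def loop_body(step, carry, model_fn):
  latents, params, timestep = carry
  noise_pred = model_fn(params, latents, timestep)
  return noise_pred, params, timestep

model_fn = lambda params, latents, t: latents * params  # stand-in for transformer.apply
carry = (jnp.ones((4, 256, 128)), jnp.array(0.5), jnp.zeros((4, 1)))
latents, _, _ = jax.lax.fori_loop(0, 1, functools.partial(loop_body, model_fn=model_fn), carry)
print(latents.shape)  # (4, 256, 128)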
+ """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + 
training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 0000000000000000000000000000000000000000..0a9fe912036cf35e35d5d8127cff178e1b4a9399 GIT binary patch literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! 
zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C literal 0 HcmV?d00001 From 1c554523fbeddc9d3f85f59ed2e23c36a32356be Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 21:24:26 +0000 Subject: [PATCH 06/34] removed diffusers import --- .../models/ltx_video/transformers/activations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 4a78b48ea..8e7ffb321 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -5,8 +5,6 @@ 
from flax import linen as nn from flax.linen.initializers import lecun_normal -from diffusers.utils.deprecation_utils import deprecate - from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer @@ -117,9 +115,9 @@ class GEGLU(nn.Module): @nn.compact def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) proj = DenseGeneral( features=self.dim_out * 2, From fd4af91ddb23b73dd7fb58958f49fb24bd938db2 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 21:30:45 +0000 Subject: [PATCH 07/34] fixed mesh --- src/maxdiffusion/max_utils.py | 63 +++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index f3f5148b2..c48b7da0f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True): devices = jax.devices() num_devices = len(devices) ##special case for ltx-video - if config.ici_fsdp_transpose_parallelism: + if "fsdp_transpose" in config.mesh_axes: num_slices = 1 # if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
@@ -628,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file From 5e17a62648b1d5d26198b3b1adea6cb89e0eb904 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 01:05:14 +0000 Subject: [PATCH 08/34] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index b0a266b70..d0f6c2e1d 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,7 @@ def load_ref_prediction(): - saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From fc60b27b30e2988a4e4fad2d0252a0f044ded270 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 03:44:05 +0000 Subject: [PATCH 09/34] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d0f6c2e1d..d40c932ba 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,8 @@ def load_ref_prediction(): - saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From 3243535216e0348c36ddd7c3e90316c66343026d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:33:16 +0000 Subject: [PATCH 10/34] changed config path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d40c932ba..43b15ba72 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -103,7 +103,9 @@ def test_one_step_transformer(self): key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From e873a17c795037318b46222c86137bb7899bb43d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:44:51 +0000 Subject: [PATCH 11/34] ruff check --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 43b15ba72..61c6909c0 100644 --- 
a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -105,7 +105,7 @@ def test_one_step_transformer(self): mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From d06dee301231259fa83f3e7849d4af0f8655d8bf Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 17:55:48 +0000 Subject: [PATCH 12/34] changed back pyconfig --- src/maxdiffusion/max_utils.py | 78 +---------------------------------- src/maxdiffusion/pyconfig.py | 41 ++++++++++-------- 2 files changed, 24 insertions(+), 95 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +288,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. @@ -474,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..f4e1900aa 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): @@ -221,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) + r = range(config.steps) \ No newline at end of file From aa7befd137cbb9bd28db098d23af7ab75de37b5d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 21:38:05 +0000 Subject: [PATCH 13/34] changed sharding back --- src/maxdiffusion/configs/ltx_video.yml | 8 +++++--- src/maxdiffusion/max_utils.py | 2 +- .../models/ltx_video/transformers/attention.py | 13 ++++++++++--- src/maxdiffusion/pyconfig.py | 2 +- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 8f1ee8a7d..d29707537 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -40,20 +40,22 @@ output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..fab895f97 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -609,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4812b89ba..e9d9d932d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -622,14 +622,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") # Computation of the spec based on the logical constraints can be found in 
logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + ("data", "fsdp", "tensor"), None, None, ) # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("fsdp", None) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index f4e1900aa..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -226,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) \ No newline at end of file + r = range(config.steps) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 61c6909c0..9a816d6e5 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -104,7 +104,7 @@ def test_one_step_transformer(self): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) From d9a35020c13df74bd91ac3e74a111d8a1e82673d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sat, 5 Jul 2025 23:07:45 +0000 Subject: [PATCH 14/34] removed testing for now --- src/maxdiffusion/max_utils.py | 61 +----- .../tests/ltx_transformer_step_test.py | 201 ------------------ .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 3 files changed, 2 insertions(+), 260 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..d4a80a347 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
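Note on the attention.py sharding change in PATCH 13 above: the patched qkvo spec shards the head dimension over the combined ('data', 'fsdp', 'tensor') mesh axes instead of the earlier eight-axis spec. A minimal, self-contained sketch of how such a PartitionSpec is consumed by shard_map follows; the mesh shape, array shapes, and the plain softmax attention body are illustrative assumptions, not the repo's flash-attention kernel.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical 3-axis mesh matching the simplified config: data=1, fsdp=N, tensor=1.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(1, n, 1), ("data", "fsdp", "tensor"))

# Heads (dim 1) sharded over all three mesh axes combined, as in the patched qkvo spec.
qkvo_spec = P(None, ("data", "fsdp", "tensor"), None, None)

def attention_block(q, k, v):
  # Inside shard_map each device only sees its slice of the head dimension.
  w = jax.nn.softmax(jnp.einsum("bhqd,bhkd->bhqk", q, k), axis=-1)
  return jnp.einsum("bhqk,bhkd->bhqd", w, v)

sharded_attention = shard_map(
    attention_block,
    mesh=mesh,
    in_specs=(qkvo_spec, qkvo_spec, qkvo_spec),
    out_specs=qkvo_spec,
)

q = k = v = jnp.ones((2, 8 * n, 16, 64))  # head axis divisible by the mesh size
out = sharded_attention(q, k, v)          # output keeps the same sharding layout as the inputs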
@@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 61c6909c0..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = 
Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 
0 HcmV?d00001 literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C From a1ad421eda51fa42d34bc4667ff32f107d9d36e8 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 
5 Jul 2025 16:15:25 -0700 Subject: [PATCH 15/34] Update pyconfig.py --- src/maxdiffusion/pyconfig.py | 39 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index edcf96164..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,7 +25,6 @@ import yaml from . import max_logging from . import max_utils -from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -42,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -103,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) - _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -112,26 +125,6 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict - @staticmethod - def wan_init(raw_keys): - if "wan_transformer_pretrained_model_name_or_path" in raw_keys: - transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] - if transformer_pretrained_model_name_or_path == "": - raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif ( - transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH - or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH - ): - # Set correct parameters for CausVid in case of user error. - raw_keys["guidance_scale"] = 1.0 - num_inference_steps = raw_keys["num_inference_steps"] - if num_inference_steps > 10: - max_logging.log( - f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
- ) - else: - raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") - @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -176,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From 615174f94f66ebd5e19c98d1cac2b0bf242b70de Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:15:57 -0700 Subject: [PATCH 16/34] Update max_utils.py --- src/maxdiffusion/max_utils.py | 78 ++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..51de312ca 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if "fsdp_transpose" in config.mesh_axes: + num_slices = 1 + # if config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -288,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
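Note on the create_device_mesh special case restored in PATCH 16 above: a -1 entry in the ici parallelism list is filled so that the product of all axes matches the per-slice device count before mesh_utils.create_device_mesh is called. A rough sketch of that behaviour follows; fill_unspecified below is a simplified stand-in for the repo's fill_unspecified_mesh_axes helper, not its actual implementation.

import jax
from jax.experimental import mesh_utils

def fill_unspecified(parallelism, num_devices):
  # Replace a single -1 with whatever factor makes the product equal num_devices.
  known = 1
  for p in parallelism:
    if p != -1:
      known *= p
  return [num_devices // known if p == -1 else p for p in parallelism]

num_devices = len(jax.devices())
ici = fill_unspecified([1, -1, 1], num_devices)  # e.g. [1, 4, 1] on four chips
devices_array = mesh_utils.create_device_mesh(ici, jax.devices())
print(devices_array.shape)  # (1, num_devices, 1)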
@@ -402,7 +474,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( From 7469c62c97f9f711114111490f9220a433b18d8a Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:26 -0700 Subject: [PATCH 17/34] Update ltx_video.yml --- src/maxdiffusion/configs/ltx_video.yml | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index d29707537..87d0e9bea 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,47 +22,32 @@ weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' -run_name: '' -output_dir: 'ltx-video-output' -save_config_to_gcs: False - -#hardware -hardware: 'tpu' -skip_jax_distributed_system: False - -jax_cache_dir: '' -weights_dtype: 'bfloat16' -activations_dtype: 'bfloat16' - - run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 - ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 + ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_tensor_transpose_parallelism: 1 @@ -84,4 +69,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True -enable_single_replica_ckpt_restoring: False \ No newline at end of file +enable_single_replica_ckpt_restoring: False From 6de4424170129d21fb4c11b38a62cccfc8e9a0d2 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:51 -0700 Subject: [PATCH 18/34] Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred --- .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 
znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! zXH)H$qVIGj8<>=`nE1kcGk7R!hZ-l>$z8)`!x1%?sSF*#Mnf4nu zHlp7~UyFBjj%7c~D)Ch3?d%^}Z}!vdcKcLxF`SH!#(&Pnvqw7Tqt~)OwO@?V?0DzX z?LTI%cp~=m+y9P#zyJRJ`TOtxpa1{f|9JoB{jc}`KL7aq=ku@6|IU9n|Kt3Z^MB5N zI{)kZxAXt5f4KhR`j_i}u7A4z>-x9r|HeNU|6%-#@ju2t8UJPcoAH0fKN|mO{HyW5 z#y=bXZT!3O|K>lK|6%@%`9J1Ang3<}oB4m{Kbrq({;T=F=0BVNZT`FY|LPy8|DgVb z`XB0_sQ;q=jru?8AF2PO{+0S)>Yu6qrv9D!f9fBq|ET_@`k(5bs{g9~t@^*}AFKbY z{YuCsuKvCH|NH~~1OI~m!9U@@@Nf7({3HGo|BC;`KjXjg@A!ZGL;fTGlK;s+ z<-hW8`M>;Q{xkoY|II(=zqjIvc%#FA=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj z|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc z`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$ScmDfC zywOQB{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7 zo&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR z-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4 z|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCscH#%vy9`oP%@BDZEJO7>k&VT2> z^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZE zJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBs zzw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9 z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|DI;+@zsd`&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~% z=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BH`m_-b@B zk&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7{c3bGoQwJI{CEC4|DFHNf9JpR-}&$S zcm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VRod&c(}Z{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHN zf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k&VT2>^WXXJ z{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu@BDZEJO7>k z&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc`S1L9{yYDj|IUBszw_Vu z@BDZEJO7>k&VT2>^WXXJ{CEC4|DFHNf9JpR-}&$Scm6y7o&U~%=fCsc=i=q|xrqPH zf9JpR-}&$S_x}g~eZD*}__?=!G!X=SLDA;RJMwP^2lG2!%kJPv&{{m3EUnd&=JBMJ zjvhRpG+WjB%IwT^aOz_&V?#v-zkm76iQasBMcY1u{lOQzGZh{5qtm0Mkh-YwxC&bQK)+Ue)gq~2Ugn%ziE`z^D0g;ksXN5|hsy3PAwH8I zxm}5D(>M3;$o*Sl{NTYDAAE?-m;3u4-&}Ryjwjv^GO+pGz1us7bpF_MU-A1@95?-U ao|^8zd%xJ#^W}$%ZVd+Vuj2K0U;97(NJu6C From 18ec247aa0b181f1ded5642093027d1cce109b3e Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:17:01 -0700 Subject: [PATCH 19/34] Delete src/maxdiffusion/tests/ltx_transformer_step_test.py --- .../tests/ltx_transformer_step_test.py | 201 ------------------ 1 file changed, 201 deletions(-) delete mode 100644 
src/maxdiffusion/tests/ltx_transformer_step_test.py diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 9a816d6e5..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config 
= json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() From 546ecab301c4e321d78bb2464a0df936c55beecb Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 00:20:43 +0000 Subject: [PATCH 20/34] ruff fixed --- src/maxdiffusion/generate_ltx_video.py | 8 +--- src/maxdiffusion/max_utils.py | 23 +---------- src/maxdiffusion/pyconfig.py | 39 +++++++++++-------- .../tests/ltx_transformer_step_test.py | 2 +- 4 files changed, 27 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/generate_ltx_video.py 
b/src/maxdiffusion/generate_ltx_video.py index 371d309e3..fa495ba1a 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -64,10 +64,6 @@ def run_inference( segment_ids=segment_ids, encoder_attention_segment_ids=encoder_attention_segment_ids, ) - prof = profiler.Profiler(config) - prof.activate(optional_postfix="transformer step") - prof.deactivate() - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) @@ -176,8 +172,8 @@ def run(config): in_shardings=(state_shardings,), out_shardings=None, ) - with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): - noise_pred = p_run_inference(states).block_until_ready() + + noise_pred = p_run_inference(states).block_until_ready() print(noise_pred) # (4, 256, 128) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index d4a80a347..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -417,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 2555b330c..9398c9156 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -191,7 +191,7 @@ def test_one_step_transformer(self): in_shardings=(state_shardings,), out_shardings=None, ) - + noise_pred = p_run_inference(states).block_until_ready() noise_pred = torch.from_numpy(np.array(noise_pred)) From 12a247fe9fa71af862a69644ea4275ec7c7d791d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 17:39:09 +0000 Subject: [PATCH 21/34] added header --- .../models/ltx_video/transformers/activations.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/adaln.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/attention.py | 16 ++++++++++++++++ .../ltx_video/transformers/caption_projection.py | 16 ++++++++++++++++ .../ltx_video/transformers/transformer3d.py | 16 ++++++++++++++++ 5 files changed, 80 insertions(+) diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 8e7ffb321..4ae1d9a00 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 4bc27e8bc..e9b287649 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Dict, Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 8a7541263..9faab1ded 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from functools import partial import math from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index f2b1af101..d8240989c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from flax import linen as nn import jax.numpy as jnp diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index cf599f26c..d6c7cf4c4 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import List, Optional, Tuple import jax From 1062c72f3f4a9429404154d51d3b9081e87fe762 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:11:17 +0000 Subject: [PATCH 22/34] license headers --- .github/workflows/UnitTests.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/__init__.py | 15 +++++++++++++++ .../models/ltx_video/gradient_checkpoint.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/linear.py | 16 ++++++++++++++++ .../models/ltx_video/repeatable_layer.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/__init__.py | 15 +++++++++++++++ 7 files changed, 95 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..c1fa771d1 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index fa495ba1a..2dec16fa6 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,19 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + from absl import app from typing import Sequence import jax diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index ef8c530ba..ee7221652 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto from typing import Optional diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index 31b21cdd9..3503ab3b4 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Union, Iterable, Tuple, Optional, Callable import numpy as np diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index aaed41048..7e9cc80c4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ From 535c75eea64044e3b87fce1b4ebcee52b80600a1 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:34:49 +0000 Subject: [PATCH 23/34] exclude test --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c1fa771d1..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: From 7af151ab5755b744717e1e0e37c8de7e60492004 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:20:30 +0000 Subject: [PATCH 24/34] change base branch --- src/maxdiffusion/configs/ltx_video.yml | 94 +- src/maxdiffusion/generate_ltx_video.py | 370 ++-- .../ltx_video/transformers/attention.py | 1680 +++++++++-------- .../ltx_video/transformers/transformer3d.py | 560 +++--- 4 files changed, 1397 insertions(+), 1307 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 140fda56c..38e9012e8 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,18 +1,3 @@ -# Copyright 2025 Google LLC - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# https://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - #hardware hardware: 'tpu' skip_jax_distributed_system: False @@ -25,30 +10,91 @@ activations_dtype: 'bfloat16' run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False +#Checkpoints +text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +frame_rate: 30 +max_sequence_length: 512 +sampler: "from_checkpoint" + + -#parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] + + +# Generation parameters +pipeline_type: multi-scale +prompt: ["A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.", "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. 
The scene is captured in real-life footage."] +height: 512 +width: 512 +num_frames: 88 #344 +flow_shift: 5.0 +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +prompt_enhancement_words_threshold: 120 +# guidance_scale: [1, 1, 6, 8, 6, 1, 1] #4.5 +# stg_scale: [0, 0, 4, 4, 4, 2, 1] #1.0 +# rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] #0.7 +# num_inference_steps: 30 +# skip_final_inference_steps: 3 +# skip_initial_inference_steps: 0 +# guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] +# skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] +stg_mode: "attention_values" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +# cfg_star_rescale: True + + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + skip_initial_inference_steps: 0 + cfg_star_rescale: True + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + skip_final_inference_steps: 0 + cfg_star_rescale: True + +#Parallelism +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded + +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 @@ -65,4 +111,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True -enable_single_replica_ckpt_restoring: False +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 2dec16fa6..90c82747c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,202 +1,196 @@ -""" - Copyright 2025 Google LLC +import numpy as np +from absl import app +from typing import Sequence +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline +from maxdiffusion import pyconfig +import jax.numpy as jnp +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler 
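For reference, a minimal sketch (with the hypothetical helper name round_dims, not part of the patch) of the rounding that the new run() below applies to the height, width and num_frames configured above: height and width are rounded up to the next multiple of 32, the frame count to the next 8k + 1, and the padding is cropped from the decoded output afterwards.

def round_dims(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
    # Height/width are rounded up to the next multiple of 32.
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    # Frame count is rounded up to the next 8k + 1.
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
    return height_padded, width_padded, num_frames_padded

# With the config above: round_dims(512, 512, 88) == (512, 512, 89).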
+from huggingface_hub import hf_hub_download +import imageio +from datetime import datetime +from maxdiffusion.utils import export_to_video - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +import os +import json +import torch +from pathlib import Path - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: -from absl import app -from typing import Sequence -import jax -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import os -import functools -import jax.numpy as jnp -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import Mesh, PartitionSpec as P -import orbax.checkpoint as ocp - - -def validate_transformer_inputs( - prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids -): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) - print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return noise_pred + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding -def 
run(config): - key = jax.random.PRNGKey(42) - - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - base_dir = os.path.dirname(__file__) - - ##load in model config - config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841 - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - - # create dummy inputs: - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join( + char.lower() for char in text if char.isalpha() or char.isspace() ) - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - validate_transformer_inputs( - prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids - ) - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, 
- latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - - noise_pred = p_run_inference(states).block_until_ready() - print(noise_pred) # (4, 256, 128) + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + current_length = 0 + + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + +def create_latent_upsampler(latent_upsampler_model_path: str, device: str): + latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) + latent_upsampler.to(device) + latent_upsampler.eval() + return latent_upsampler + +def get_unique_filename( + base: str, + ext: str, + prompt: str, + seed: int, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + for i in range(index_range): + filename = dir / \ + f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError( + f"Could not find a unique filename after {index_range} attempts." + ) + + +def run(config): + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding( + config.height, config.width, height_padded, width_padded) + # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold + # prompt_word_count = len(config.prompt.split()) + # enhance_prompt = ( + # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold + # ) + + seed = 10 # change this, generator in pytorch, used in prepare_latents + generator = torch.Generator().manual_seed(seed) + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) + if config.pipeline_type == "multi-scale": #move this to pipeline file?? 
+ spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile( + spatial_upscaler_model_name_or_path + ): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir= "/mnt/disks/diffusionproj", + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = create_latent_upsampler( + spatial_upscaler_model_path, "cpu" #device set to cpu for now + ) + pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) + stg_mode = config.stg_mode + if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": + skip_layer_strategy = SkipLayerStrategy.AttentionValues + elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": + skip_layer_strategy = SkipLayerStrategy.AttentionSkip + elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": + skip_layer_strategy = SkipLayerStrategy.Residual + elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": + skip_layer_strategy = SkipLayerStrategy.TransformerBlock + else: + raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") + # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, + # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, + # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images + images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config) + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, :config.num_frames, + pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + print(output_filename) + # Write video + with 
imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 9faab1ded..fba1cdfc3 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -15,10 +15,10 @@ # This implementation is based on the Torch version available at: # https://github.com/Lightricks/LTX-Video/tree/main from functools import partial +import functools import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto - import jax import jax.nn as jnn import jax.numpy as jnp @@ -40,876 +40,898 @@ ) + class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - - def __call__(self, x): - return x + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." - - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. 
Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] - - # 0. 
Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. 
Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. Self-Attention - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) - - attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) - - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. 
Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states - -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
- ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) - - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. 
Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", dtype=self.dtype, weight_dtype=self.weight_dtype, - name="to_out.0", matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any(device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print("Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 - assert cross_attention_kwargs.get("scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key = self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") - - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") - - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = 
nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") - - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" - - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) - - if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - - hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) - - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") - ) - - hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states + ) + # 4. 
Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + + def __call__( + self, + index: int, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + skip_layer_strategy = SkipLayerStrategy.AttentionValues + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") -class AttentionOp(nn.Module): + batch_size = hidden_states.shape[0] - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
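A minimal sketch (hypothetical helper ada_modulate, illustration only) of the "single_scale_shift" adaptive norm used by this block: the learned scale_shift_table is added to the per-token timestep embedding, split into shift/scale/gate terms, and the normalized hidden states are modulated before self-attention.

import jax.numpy as jnp

def ada_modulate(norm_x, scale_shift_table, timestep_emb):
    # norm_x:            [batch, tokens, dim]   layer-normed hidden states
    # scale_shift_table: [6, dim]               learned per-block parameters
    # timestep_emb:      [batch, t, 6 * dim]    per-timestep embedding (t is 1 or tokens)
    b, t = timestep_emb.shape[0], timestep_emb.shape[1]
    ada = scale_shift_table[None, None] + timestep_emb.reshape(b, t, 6, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        jnp.squeeze(a, axis=2) for a in jnp.split(ada, 6, axis=2)
    )
    # Scale/shift before self-attention; gate_msa scales the attention residual,
    # and the *_mlp terms are applied around the feed-forward in the same way.
    modulated = norm_x * (1 + scale_msa) + shift_msa
    return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp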
- # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # ("data", "fsdp", "fsdp_transpose", "expert"), - # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - # None, - # None, - # ) - qkvo_sharding_spec = jax.sharding.PartitionSpec( - "data", - "fsdp", - None, - "tensor", - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) - - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. 
Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index = index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) - ** SRAM cache size for TPU - V5P - 1MB SRAM per core + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + block_index = -1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - 
) +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) -class ExplicitAttention(nn.Module): + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters } + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) #here skip_layer_mask is (48,3), changed this currently! 
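The q_segment_ids / kv_attention_segment_ids used below feed either the TPU flash-attention kernel (as SegmentIds) or the ExplicitAttention fallback in this file. A minimal sketch (hypothetical helper segment_mask, illustration only) of the masking semantics: tokens may attend to each other only where their segment ids match.

import jax.numpy as jnp

def segment_mask(q_segment_ids, kv_segment_ids):
    # q_segment_ids:  [batch, q_tokens]
    # kv_segment_ids: [batch, kv_tokens]
    # True where attention is allowed; broadcasts over the heads axis.
    return q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]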
+ + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." - attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. - """ - - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - NOTE: if weight is in mixed precision, the operand should be in the same precision. 
- Args: - hidden_states (jax.Array): Input data + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) - Returns: - jax.Array: Normed data - """ + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionValues + ): + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." 
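+        # Editorial note: only LayerNorm and GroupNorm are dispatched on below. GroupNorm
+        # normalizes over the channel axis, so the input is transposed to channels-first
+        # around the call; LayerNorm operates on the last axis and needs no transpose.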
+ + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkvo_sharding_spec = jax.sharding.PartitionSpec( + None, + None, + None, + None, + ) + qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
+
+        TPU kernel ops run in grids; the bigger the grid, the more data is loaded into SRAM,
+        and we want to utilize that SRAM as well as we can.
+
+        Grids that are too big will cause cache misses and slow down the computation while the
+        faster SRAM fetches the other block's data from the slower HBM.
+
+        A balance has to be struck to get the best performance. That balance should be computed
+        from the information supplied by q and k (which give the query and key/value sequence
+        lengths) together with the SRAM cache size.
+
+        ** SRAM cache size for TPU
+            V5P - 1MB SRAM per core
+
+        Args:
+            q (jax.Array): Query tensor to be used
+            k (jax.Array): Key tensor to be used
+
+        Returns:
+            BlockSizes: Grid block sizes
+        """
+        max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+        return BlockSizes(
+            block_q=min(max_block_size, q.shape[-2]),
+            block_k_major=min(max_block_size, k.shape[-2]),
+            block_k=min(max_block_size, k.shape[-2]),
+            block_b=min(1, q.shape[0]),
+            block_q_major_dkv=min(max_block_size, q.shape[-2]),
+            block_k_major_dkv=min(max_block_size, k.shape[-2]),
+            block_q_dkv=min(max_block_size, q.shape[-2]),
+            block_k_dkv=min(max_block_size, k.shape[-2]),
+            block_q_dq=min(max_block_size, q.shape[-2]),
+            block_k_dq=min(512, k.shape[-2]),
+            block_k_major_dq=min(max_block_size, k.shape[-2]),
+        )
-    if self.elementwise_affine:
-      # convert into half-precision if necessary
-      hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype)
-    else:
-      hidden_states = hidden_states.astype(input_dtype)
-    return hidden_states
+class ExplicitAttention(nn.Module):
+    def __call__(
+        self,
+        q: jax.Array,
+        k: jax.Array,
+        v: jax.Array,
+        q_segment_ids: jax.Array,
+        kv_segment_ids: jax.Array,
+        sharding_mesh: Optional[jax.sharding.Mesh] = None,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        assert sharding_mesh is None, "Explicit attention does not support sharding mesh."
+        attn_mask = None
+        if kv_segment_ids is not None:
+            q_segment_ids_expanded = q_segment_ids[:, None, :, None]
+            kv_segment_ids_expanded = kv_segment_ids[:, None, None, :]
+            attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded
+
+        scale_factor = 1 / jnp.sqrt(q.shape[-1])
+        attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == jnp.bool_:
+                attn_bias = jnp.where(attn_mask, attn_bias, float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        attn_weight = q @ k.swapaxes(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = jnn.softmax(attn_weight, axis=-1)
+
+        return attn_weight @ v
+
+
+class RMSNorm(nn.Module):
+    """
+    RMSNorm is a normalization layer that normalizes the input using the root mean square.
+    """
+
+    epsilon: float
+    dtype: jnp.dtype = jnp.float32
+    elementwise_affine: bool = True
+    weight_dtype: jnp.dtype = jnp.float32
+    kernel_axes: Tuple[Optional[str], ...] = ()
+    scale_init: Initializer = nn.initializers.ones
+
+    @nn.compact
+    def __call__(self, hidden_states: jax.Array) -> jax.Array:
+        """
+        Forward pass of the RMSNorm layer.
+
+        First we compute the variance (mean of the square of the input)
+        and then normalize the input using the root mean square.
+
+        NOTE: if weight is in mixed precision, the operand should be in the same precision.
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError(f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF 
as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out + return out \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index d6c7cf4c4..5213beb7d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -29,277 +29,305 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( # noqa C408 - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def init_weights(self, in_channels, key, caption_channels, eval_only=True): - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "hidden_states": (batch_size, num_tokens, in_channels), - "indices_grid": (batch_size, 3, num_tokens), - "encoder_hidden_states": (batch_size, 128, caption_channels), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", 
"encoder_attention_mask"] else jnp.bool - ) - - if eval_only: - return jax.eval_shape( - self.init, - key, - **example_inputs, - )["params"] - else: - return self.init(key, **example_inputs)["params"] - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - - if self.num_layers > 0: - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - ) - # Output processing - - scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." 
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + def init_weights(self, in_channels, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else 
jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + jax.random.PRNGKey(42), ##need to change? + **example_inputs, + )["params"] + else: + return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones( + (self.num_layers, batch_size * num_conds), dtype=self.dtype + ) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + - Args: - x (jax.Array): log value - base (jax.Array): base of the log - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. 
- """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - dtype = jnp.float32 # We need full precision in the freqs_cis computation. - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
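+        # Editorial sketch of the computation below: each of the three axes (t, h, w) of
+        # `indices_grid` is normalized by its configured maximum, mapped from [0, 1] to
+        # [-1, 1], and multiplied by dim // 6 log-spaced frequencies scaled by pi / 2. The
+        # cos/sin of those angles are then duplicated for each (x1, x2) feature pair to form
+        # the rotary embedding, with the first dim % 6 channels padded with the identity
+        # rotation (cos = 1, sin = 0).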
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) From 634591b86ab7ef1fda2c3a7ac803b37ce0cd80ee Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:46:25 +0000 Subject: [PATCH 25/34] save now --- src/maxdiffusion/configs/ltx_video.yml | 23 +- .../models/ltx_video/autoencoders/__init__.py | 0 .../ltx_video/autoencoders/causal_conv3d.py | 63 + .../autoencoders/causal_video_autoencoder.py | 1398 +++++++++++++++++ .../ltx_video/autoencoders/conv_nd_factory.py | 90 ++ .../ltx_video/autoencoders/dual_conv3d.py | 217 +++ .../autoencoders/latent_upsampler.py | 203 +++ .../ltx_video/autoencoders/pixel_norm.py | 12 + .../ltx_video/autoencoders/pixel_shuffle.py | 33 + .../models/ltx_video/autoencoders/vae.py | 380 +++++ .../ltx_video/autoencoders/vae_encode.py | 247 +++ .../autoencoders/video_autoencoder.py | 1045 ++++++++++++ .../ltx_video/transformers/attention.py | 9 +- .../transformers/symmetric_patchifier.py | 84 + .../ltx_video/transformers/transformer3d.py | 6 +- .../utils/diffusers_config_mapping.py | 174 ++ .../ltx_video/utils/prompt_enhance_utils.py | 226 +++ .../ltx_video/utils/skip_layer_strategy.py | 8 + .../models/ltx_video/utils/torch_utils.py | 25 + .../pipelines/ltx_video/__init__.py | 0 .../pipelines/ltx_video/ltx_video_pipeline.py | 1374 ++++++++++++++++ .../schedulers/scheduling_rectified_flow.py | 357 +++++ 22 files changed, 5953 insertions(+), 21 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py create mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py create mode 100644 
src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_utils.py create mode 100644 src/maxdiffusion/pipelines/ltx_video/__init__.py create mode 100644 src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py create mode 100644 src/maxdiffusion/schedulers/scheduling_rectified_flow.py diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 38e9012e8..cba635a1a 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -24,7 +24,7 @@ sampler: "from_checkpoint" # Generation parameters pipeline_type: multi-scale -prompt: ["A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.", "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."] +prompt: "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage." 
height: 512 width: 512 num_frames: 88 #344 @@ -68,34 +68,29 @@ second_pass: skip_final_inference_steps: 0 cfg_star_rescale: True -#Parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 - -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 -ici_fsdp_transpose_parallelism: 1 -ici_sequence_parallelism: 1 -ici_tensor_transpose_parallelism: 1 -ici_expert_parallelism: 1 -ici_sequence_parallelism: 1 - diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py new file mode 100644 index 000000000..98249c2f5 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py @@ -0,0 +1,63 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py new file mode 100644 index 000000000..1255b6d34 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1398 @@ +import 
json
+import os
+from functools import partial
+from types import SimpleNamespace
+from typing import Any, Mapping, Optional, Tuple, Union, List
+from pathlib import Path
+
+import torch
+import numpy as np
+from einops import rearrange
+from torch import nn
+from diffusers.utils import logging
+import torch.nn.functional as F
+from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
+from safetensors import safe_open
+
+
+from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
+from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm
+from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND
+from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper
+from maxdiffusion.models.ltx_video.transformers.attention import Attention
+from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import (
+    diffusers_and_ours_config_mapping,
+    make_hashable_key,
+    VAE_KEYS_RENAME_DICT,
+)
+
+PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CausalVideoAutoencoder(AutoencoderKLWrapper):
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *args,
+        **kwargs,
+    ):
+        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        if (
+            pretrained_model_name_or_path.is_dir()
+            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
+        ):
+            config_local_path = pretrained_model_name_or_path / "config.json"
+            config = cls.load_config(config_local_path, **kwargs)
+
+            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
+            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
+
+            statistics_local_path = (
+                pretrained_model_name_or_path / "per_channel_statistics.json"
+            )
+            if statistics_local_path.exists():
+                with open(statistics_local_path, "r") as file:
+                    data = json.load(file)
+                transposed_data = list(zip(*data["data"]))
+                data_dict = {
+                    col: torch.tensor(vals)
+                    for col, vals in zip(data["columns"], transposed_data)
+                }
+                std_of_means = data_dict["std-of-means"]
+                mean_of_means = data_dict.get(
+                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
+                )
+                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
+                    std_of_means
+                )
+                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
+                    mean_of_means
+                )
+
+        elif pretrained_model_name_or_path.is_dir():
+            config_path = pretrained_model_name_or_path / "vae" / "config.json"
+            with open(config_path, "r") as f:
+                config = make_hashable_key(json.load(f))
+
+            assert config in diffusers_and_ours_config_mapping, (
+                "Provided diffusers checkpoint config for VAE is not supported. "
+                "We only support diffusers configs found in Lightricks/LTX-Video."
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "CausalVideoAutoencoder" + ), "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError( + f"latent_log_var={latent_log_var} requires use_quant_conv=False" + ) + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels + // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + 
patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_space", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_time_res", + ] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } + ckpt_state_dict = { + key: value + for key, value in state_dict.items() + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + + model_keys = set(name for name, _ in self.named_modules()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = 
block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. 
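+            Use 2 for images, 3 for full 3D (video) convolutions, or (2, 1) for
+            factorized spatial/temporal convolutions (DualConv3d).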
+ in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name.startswith("compress"): + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = 
DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0 + ) + self.last_scale_shift_table = nn.Parameter( + torch.randn(2, output_channel) / output_channel**0.5 + ) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError( + "attention_head_dim must be less than or equal to in_channels" + ) + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet( + hidden_states, 
causal=causal, timestep=timestep_embed + ) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, frames * height * width + ).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad( + hidden_states, (0, 0, 0, pad_len), "constant", 0 + ) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=( + None if not attention.use_tpu_flash_attention else mask + ), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, frames, height, width + ) + else: + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = ( + np.prod(stride) * in_channels // out_channels_reduction_factor + ) + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // 
self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, 
per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_demo_config( + latent_channels: int = 64, +): + encoder_blocks = [ + ("res_x", {"num_layers": 2}), + ("compress_space_res", {"multiplier": 2}), + ("compress_time_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ] + return { + "_class_name": 
"CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "spatial_padding_mode": "replicate", + } + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_demo_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) + + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + + timestep = torch.ones(input_videos.shape[0]) * 0.1 + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape, timestep=timestep + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode( + image_latent, target_shape=image_latent.shape, timestep=timestep + ).sample + + first_frame_latent = latent[:, :, :1, :, :] + + assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py new file mode 100644 index 000000000..1aa55ed9c --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py @@ -0,0 +1,90 @@ +from typing import Tuple, Union + +import torch + +from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d +from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, 
+ stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", +): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py new file mode 100644 index 000000000..dcf889296 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py @@ -0,0 +1,217 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." 
+ ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and 
dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py new file mode 100644 index 000000000..8cb7d7d68 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py @@ -0,0 +1,203 @@ +from typing import Optional, Union +from pathlib import Path +import os +import json + +import torch +import torch.nn as nn +from einops import rearrange +from diffusers import ConfigMixin, ModelMixin +from safetensors.torch import safe_open + +from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
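+
+    The latent is passed through an initial convolution, a stack of ResBlocks,
+    a PixelShuffleND-based upsampling stage, a second stack of ResBlocks and a
+    final convolution back to `in_channels`. With `dims=2`, the frame axis is
+    folded into the batch so each frame is upsampled independently.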
+ + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + 
"spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler + + +if __name__ == "__main__": + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py new file mode 100644 index 000000000..9bc3ea60e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py new file mode 100644 index 000000000..4e79ae284 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from einops import rearrange + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py new file mode 100644 index 000000000..821a6b32b --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py @@ -0,0 +1,380 @@ +from typing import Optional, Union + +import torch +import inspect +import math +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd + + +class AutoencoderKLWrapper(ModelMixin, ConfigMixin): + """Variational 
Autoencoder (VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + normalize_latent_channels: bool = False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + self.normalize_latent_channels = normalize_latent_channels + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd( + quant_dims, 2 * latent_channels, 2 * latent_channels, 1 + ) + self.post_quant_conv = make_conv_nd( + quant_dims, latent_channels, latent_channels, 1 + ) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + + if normalize_latent_channels: + if dims == 2: + self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert ( + z_sample_size % 8 == 0 or z_sample_size == 1 + ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False + + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
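+        # Neighbouring tiles overlap by `tile_overlap_factor`; after encoding,
+        # each latent tile is blended with the tile above (blend_v) and to its
+        # left (blend_h) and cropped to `row_limit`, hiding seams at tile borders.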
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( + 1 - z / blend_extent + ) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
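+        # Each latent tile is decoded to `tile_target_shape`, blended with the
+        # previously decoded tiles above (blend_v) and to the left (blend_h),
+        # and cropped to `row_limit` pixels before the rows are concatenated.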
+ rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = ( + sizes + [z.shape[2] - sum(sizes)] + if z.shape[2] - sum(sizes) > 0 + else sizes + ) + tiles = z.split(sizes, dim=2) + moments_tiles = [ + ( + self._hw_tiled_encode(z_tile, return_dict) + if self.use_hw_tiling + else self._encode(z_tile) + ) + for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = ( + self._hw_tiled_encode(z, return_dict) + if self.use_hw_tiling + else self._encode(z) + ) + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + _, c, _, _, _ = z.shape + z = torch.cat( + [ + self.latent_norm_out(z[:, : c // 2, :, :, :]), + z[:, c // 2 :, :, :, :], + ], + dim=1, + ) + elif isinstance(self.latent_norm_out, nn.BatchNorm2d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) + running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) + eps = self.latent_norm_out.eps + + z = z * torch.sqrt(running_var + eps) + running_mean + elif isinstance(self.latent_norm_out, nn.BatchNorm3d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + moments = self._normalize_latent_channels(moments) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self._unnormalize_latent_channels(z) + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and 
z.shape[2] > self.z_sample_size > 1: + reduction_factor = int( + self.encoder.patch_size_t + * 2 + ** ( + len(self.encoder.down_blocks) + - 1 + - math.sqrt(self.encoder.patch_size) + ) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode(z, target_shape) + if self.use_hw_tiling + else self._decode(z, target_shape=target_shape, timestep=timestep) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py new file mode 100644 index 000000000..5a0aeeccf --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py @@ -0,0 +1,247 @@ +from typing import Tuple +import torch +from diffusers import AutoencoderKL +from einops import rearrange +from torch import Tensor + + +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import ( + Downsample3D, + VideoAutoencoder, +) + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def vae_encode( + media_items: Tensor, + vae: AutoencoderKL, + split_size: int = 1, + vae_per_channel_normalize=False, +) -> Tensor: + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. 
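+            For example, with a batch of 10 media items and split_size=2, encoding runs on two
+            sub-batches of 5 items each (the batch size must be divisible by split_size).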
+            Defaults to 1, which processes all items in a single batch.
+
+    Returns:
+        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
+        to match the input shape, scaled by the model's configuration.
+
+    Examples:
+        >>> import torch
+        >>> from diffusers import AutoencoderKL
+        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
+        >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
+        >>> latents = vae_encode(images, vae)
+        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.
+
+    Note:
+        In case of a video, the function encodes the media item frame-by-frame.
+    """
+    is_video_shaped = media_items.dim() == 5
+    batch_size, channels = media_items.shape[0:2]
+
+    if channels != 3:
+        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
+
+    if is_video_shaped and not isinstance(
+        vae, (VideoAutoencoder, CausalVideoAutoencoder)
+    ):
+        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
+    if split_size > 1:
+        if len(media_items) % split_size != 0:
+            raise ValueError(
+                "Error: The batch size must be divisible by 'train.vae_bs_split'"
+            )
+        encode_bs = len(media_items) // split_size
+        # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
+        latents = []
+        if media_items.device.type == "xla":
+            xm.mark_step()
+        for image_batch in media_items.split(encode_bs):
+            latents.append(vae.encode(image_batch).latent_dist.sample())
+            if media_items.device.type == "xla":
+                xm.mark_step()
+        latents = torch.cat(latents, dim=0)
+    else:
+        latents = vae.encode(media_items).latent_dist.sample()
+
+    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
+    if is_video_shaped and not isinstance(
+        vae, (VideoAutoencoder, CausalVideoAutoencoder)
+    ):
+        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
+    return latents
+
+
+def vae_decode(
+    latents: Tensor,
+    vae: AutoencoderKL,
+    is_video: bool = True,
+    split_size: int = 1,
+    vae_per_channel_normalize=False,
+    timestep=None,
+) -> Tensor:
+    is_video_shaped = latents.dim() == 5
+    batch_size = latents.shape[0]
+
+    if is_video_shaped and not isinstance(
+        vae, (VideoAutoencoder, CausalVideoAutoencoder)
+    ):
+        latents = rearrange(latents, "b c n h w -> (b n) c h w")
+    if split_size > 1:
+        if len(latents) % split_size != 0:
+            raise ValueError(
+                "Error: The batch size must be divisible by 'train.vae_bs_split'"
+            )
+        encode_bs = len(latents) // split_size
+        image_batch = [
+            _run_decoder(
+                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
+            )
+            for latent_batch in latents.split(encode_bs)
+        ]
+        images = torch.cat(image_batch, dim=0)
+    else:
+        images = _run_decoder(
+            latents, vae, is_video, vae_per_channel_normalize, timestep
+        )
+
+    if is_video_shaped and not isinstance(
+        vae, (VideoAutoencoder, CausalVideoAutoencoder)
+    ):
+        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
+    return images
+
+
+def _run_decoder(
+    latents: Tensor,
+    vae: AutoencoderKL,
+    is_video: bool,
+    vae_per_channel_normalize=False,
+    timestep=None,
+) -> Tensor:
+    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
+        *_, fl, hl, wl = latents.shape
+        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
+        latents = latents.to(vae.dtype)
+        vae_decode_kwargs = {}
+        if timestep is not None:
+            vae_decode_kwargs["timestep"] = timestep
+        image = vae.decode(
+            un_normalize_latents(latents, vae, vae_per_channel_normalize),
return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len( + [ + block + for block in vae.encoder.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + spatial = vae.config.patch_size * 2**down_blocks + temporal = ( + vae.config.patch_size_t * 2**down_blocks + if isinstance(vae, VideoAutoencoder) + else 1 + ) + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords( + latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors( + latent_coords, scale_factors, causal_fix + ) + return pixel_coords + + +def latent_to_pixel_coords_from_factors( + latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False +) -> Tensor: + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py new file mode 100644 index 000000000..5b9ea640b --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py @@ -0,0 +1,1045 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from maxdiffusion.models.ltx_video.utils.torch_utils import Identity +from 
maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm +from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "VideoAutoencoder" + ), "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels + // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels + // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + 
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels + for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + * self.patch_size + ) + + def forward( + self, sample: torch.FloatTensor, return_features=False + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)( + sample, downsample_in_time=downsample_in_time + ) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block + and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, block_out_channels[0], out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward( + self, hidden_states: torch.FloatTensor, downsample_in_time + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample( + hidden_states, downsample_in_time=downsample_in_time + ) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
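+
+    Note:
+        The mid-block is a plain stack of `num_layers` [`ResnetBlock3D`] blocks at a constant
+        channel width; it applies no attention.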
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D( + dims=dims, channels=out_channels, out_channels=out_channels + ) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, upsample_in_time=True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
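+        norm_layer (`str`, *optional*, defaults to `"group_norm"`): The normalization layer to use. Can be either `group_norm` or `pixel_norm`.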
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
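+    :param dims: convolution dimensionality passed to `make_conv_nd` (e.g. 2, 3 or (2, 1)).
+    :param out_channels: channels in the output; defaults to ``channels`` when not provided.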
+ """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd( + dims, channels, out_channels, kernel_size=3, padding=1, bias=True + ) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + _, _, h, w = x.shape + + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) + + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange( + x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d + ) + # b c d h w + + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + + return self.conv(x) + + +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + 
"patch_size": 1, + } + + return config + + +def create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index fba1cdfc3..caff27add 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -670,12 +670,13 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): # ) # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), 
"sequence") qkvo_sharding_spec = jax.sharding.PartitionSpec( - None, - None, - None, + "data", + "fsdp", None, + "tensor", ) - qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py new file mode 100644 index 000000000..2eca32033 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 5213beb7d..7d068b0f6 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -153,7 +153,7 @@ def scale_shift_table_init(key): weight_dtype=self.weight_dtype, matmul_precision=self.matmul_precision, ) - def 
init_weights(self, in_channels, caption_channels, eval_only=True): + def init_weights(self, key, in_channels, caption_channels, eval_only=True): example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { @@ -172,11 +172,11 @@ def init_weights(self, in_channels, caption_channels, eval_only=True): if eval_only: return jax.eval_shape( self.init, - jax.random.PRNGKey(42), ##need to change? + key, ##need to change? **example_inputs, )["params"] else: - return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] + return self.init(key, **example_inputs)['params'] def create_skip_layer_mask( self, diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 000000000..53c0082d1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + 
"out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py new file mode 100644 index 000000000..901051728 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py @@ -0,0 +1,226 @@ +import logging +from typing import Union, List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. 
+Do not change the user input intent, just enhance it. +Keep within 150 words. +For best results, build your prompts using this structure: +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + +I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Keep within 150 words. +For best results, build your prompts using this structure: +Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Align to the image caption if it contradicts the user text input. +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + + +def tensor_to_pil(tensor): + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 + + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 + + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) + + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + + # Convert to PIL Image + return Image.fromarray(numpy_image) + + +def generate_cinematic_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompt: Union[str, List[str]], + conditioning_items: Optional[List] = None, + max_new_tokens: int = 256, +) -> List[str]: + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item( + first_frame_conditioning_item + ) + + assert len(first_frames) == len( + prompts + ), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts + + +def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: + frames_tensor = conditioning_item.media_item + return [ + tensor_to_pil(frames_tensor[i, :, 0, :, :]) + 
for i in range(frames_tensor.shape[0]) + ] + + +def _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + first_frames: List[Image.Image], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + image_captions = _generate_image_captions( + image_caption_model, image_caption_processor, first_frames + ) + + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_image_captions( + image_caption_model, + image_caption_processor, + images: List[Image.Image], + system_prompt: str = "", +) -> List[str]: + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor( + image_caption_prompts, images, return_tensors="pt" + ).to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int +) -> List[str]: + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate( + **model_inputs, max_new_tokens=max_new_tokens + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, outputs) + ] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + ) + + return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 000000000..30f9016e1 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py new file mode 100644 index 000000000..991b07c36 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py @@ -0,0 +1,25 @@ 
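+"""Small torch helpers for the LTX-Video port: `append_dims` pads trailing
+dimensions for broadcasting, and `Identity` is an argument-insensitive no-op
+module used where an optional layer is disabled."""
+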
+import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/src/maxdiffusion/pipelines/ltx_video/__init__.py b/src/maxdiffusion/pipelines/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py new file mode 100644 index 000000000..d0a4ea6da --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -0,0 +1,1374 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
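A minimal usage sketch for the append_dims helper defined above (illustrative only, not part of the patch): it right-pads singleton dimensions so a per-sample scalar broadcasts against a (b, c, f, h, w) latent tensor, the same pattern its JAX counterpart append_dims_jax follows in the rectified-flow scheduler added later in this patch.

import torch
from maxdiffusion.models.ltx_video.utils.torch_utils import append_dims

dt = torch.tensor([0.5, 0.25])              # per-sample step sizes, shape (2,)
latents = torch.randn(2, 128, 8, 16, 16)    # (b, c, f, h, w)
dt = append_dims(dt, latents.ndim)          # -> shape (2, 1, 1, 1, 1)
stepped = latents - dt * torch.randn_like(latents)  # broadcasts with no explicit reshape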
+import argparse +import math +import os +import random +from jax import Array +from datetime import datetime +from pathlib import Path +from diffusers import AutoencoderKL +from typing import Optional, List, Union, Tuple +from einops import rearrange +import torch.nn.functional as F +from diffusers.utils.torch_utils import randn_tensor +# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +import yaml +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, + T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) + + +import imageio +import json +import numpy as np +import torch +from safetensors import safe_open +from PIL import Image +from transformers import ( + T5EncoderModel, + T5Tokenizer, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) +from huggingface_hub import hf_hub_download +from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + latent_to_pixel_coords, + vae_decode, + vae_encode, + un_normalize_latents, + normalize_latents, +) +from diffusers.image_processor import VaeImageProcessor +from ltx_video.schedulers.rf import RectifiedFlowScheduler +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +import ltx_video.pipelines.crf_compressor as crf_compressor +from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from math import e +from types import NoneType +from typing import Any, Dict +import numpy as np +import inspect + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec as P +from typing import Optional, Union, List +import torch +from maxdiffusion.checkpointing import checkpointing_utils +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ...pyconfig import HyperParameters +# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState +from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState +from ...max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations +) +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import json +import functools +import orbax.checkpoint as ocp +import pickle + + +class PickleCheckpointHandler(ocp.CheckpointHandler): + def save(self, directory: str, item, args=None): + os.makedirs(directory, exist_ok=True) + with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: + pickle.dump(item, f) + + def restore(self, directory: str, args=None): + with open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f: + return pickle.load(f) + + def structure(self, directory: str): + return {} # not needed for simple pickle-based handling + + +def save_tensor_dict(tensor_dict, timestep): + base_dir = os.path.dirname(__file__) + local_path = os.path.join(base_dir, f"schedulerTest{timestep}") + + try: + torch.save(tensor_dict, local_path) + print(f"Dictionary of tensors saved to: {local_path}") + except Exception as e: + print(f"Error saving dictionary: {e}") + raise + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, 
segment_ids, encoder_attention_segment_ids): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", + encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def prepare_extra_step_kwargs(generator): + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + +class LTXVideoPipeline: + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: AutoencoderKL, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( + self.vae + ) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + # "Wan-AI/Wan2.1-T2V-14B-Diffusers", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( + # "Lightricks/LTX-Video", + # subfolder="scheduler", + # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p + # ) + # import pdb; pdb.set_trace() + # scheduler = FlaxRectifiedFlowMultistepScheduler( + # sampler="LinearQuadratic" + # ) + # scheduler_state = scheduler.create_state() + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + 
model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", + "causal_temporal_positioning", "in_channels", "ckpt_path"] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + transformer = Transformer3DModel( + # change this sharding back + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) + + weights_init_fn = functools.partial( + transformer.init_weights, + jax.random.PRNGKey(42), + in_channels, + model_config['caption_channels'], + eval_only=True + ) + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put( + transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + return vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + # text_encoder = T5EncoderModel.from_pretrained( + # ckpt_path, subfolder="text_encoder" + # ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + # tokenizer = T5Tokenizer.from_pretrained( + # ckpt_path, subfolder="tokenizer" + # ) + t5_tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, max_length=config.max_sequence_length, use_fast=True + ) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer + + @classmethod + def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer( + config) + + # load from pytorch version + models_dir = "/mnt/disks/diffusionproj" #edit this! 
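A rough usage sketch of the classmethod being defined here, assuming a HyperParameters config built from configs/ltx_video.yml; the variable names and numeric values below are placeholders, only the parameter names come from this patch.

# Hypothetical driver code, not part of the patch.
pipeline = LTXVideoPipeline.from_pretrained(config)   # config: HyperParameters
video = pipeline(
    height=512,                 # divided by the VAE spatial scale factor to get the latent size
    width=768,
    num_frames=121,             # CausalVideoAutoencoder adds one extra latent frame when is_video=True
    is_video=True,
    num_inference_steps=40,
    guidance_scale=4.5,
    output_type="pil",          # the text prompt itself is read from config.prompt inside __call__
)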
+ ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + vae = cls.load_vae(ltxv_model_path) + vae = vae.to(torch.bfloat16) + text_encoder = cls.load_text_encoder( + config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer( + config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( + config) + else: + prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + """ + Perform the denoising step for the required tokens, based on the current timestep and + conditioning mask: + Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) + and will start to be denoised when the current timestep is equal or lower than their + conditioning timestep. 
+ (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) + """ + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( #currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps # Use the original num_inference_steps here for the check + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[ + skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + ] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) + + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = ( + text_encoder_max_tokens # TPU supports only lengths multiple of 128 + ) + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + # text_enc_device = next(self.text_encoder.parameters()) + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, max_length - 1: -1] + ) + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids, 
attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + print(isinstance(prompt_embeds, Array)) + # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + # prompt_embeds = prompt_embeds.view( + # bs_embed * num_images_per_prompt, seq_len, -1 + # ) + prompt_embeds = jnp.reshape( + prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile( + prompt_attention_mask, (1, num_images_per_prompt)) + # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + # prompt_attention_mask = prompt_attention_mask.view( + # bs_embed * num_images_per_prompt, -1 + # ) + prompt_attention_mask = jnp.reshape( + prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance hasn't changed yet + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, + (1, num_images_per_prompt, 1) + ) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, + (batch_size * num_images_per_prompt, seq_len, -1) + ) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, + (1, num_images_per_prompt) + ) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, + (bs_embed * num_images_per_prompt, -1) + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds,#(1, 256, 4096) + prompt_attention_mask, #1, 256 + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." 
+ ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items, + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert ( + latents.shape == latent_shape + ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." + + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor( + (b, f * h * w, c), generator=generator, device=device, dtype=dtype + ) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, CausalVideoAutoencoder) + + # if conditioning_items: + # batch_size, _, num_latent_frames = init_latents.shape[:3] + + # init_conditioning_mask = torch.zeros( + # init_latents[:, 0, :, :, :].shape, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents = [] + # extra_conditioning_pixel_coords = [] + # extra_conditioning_mask = [] + # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # # Process each conditioning item + # for conditioning_item in conditioning_items: + # conditioning_item = self._resize_conditioning_item( + # conditioning_item, height, width + # ) + # media_item = conditioning_item.media_item + # media_frame_number = conditioning_item.media_frame_number + # strength = conditioning_item.conditioning_strength + # assert media_item.ndim == 5 # (b, c, f, h, w) + # b, c, n_frames, h, w = media_item.shape + # assert ( + # height == h and width == w + # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + # assert n_frames % 8 == 1 + # assert ( + # media_frame_number >= 0 + # and media_frame_number + n_frames <= num_frames + # ) + + # # Encode the provided conditioning media item + # media_item_latents = vae_encode( + # media_item.to(dtype=self.vae.dtype, device=self.vae.device), + # self.vae, + # vae_per_channel_normalize=vae_per_channel_normalize, + # ).to(dtype=init_latents.dtype) + + # # Handle the different conditioning cases + # if media_frame_number == 0: + # # Get the target spatial position of the latent conditioning item + # media_item_latents, l_x, l_y = self._get_latent_spatial_position( + # media_item_latents, + # conditioning_item, + # height, + # width, + # strip_latent_border=True, + # ) + # b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # # First frame or sequence - just 
update the initial noise latents and the mask + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + # torch.lerp( + # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + # media_item_latents, + # strength, + # ) + # ) + # init_conditioning_mask[ + # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + # ] = strength + # else: + # # Non-first frame or sequence + # if n_frames > 1: + # # Handle non-first sequence. + # # Encoded latents are either fully consumed, or the prefix is handled separately below. + # ( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # ) = self._handle_non_first_conditioning_sequence( + # init_latents, + # init_conditioning_mask, + # media_item_latents, + # media_frame_number, + # strength, + # ) + + # # Single frame or sequence-prefix latents + # if media_item_latents is not None: + # noise = randn_tensor( + # media_item_latents.shape, + # generator=generator, + # device=media_item_latents.device, + # dtype=media_item_latents.dtype, + # ) + + # media_item_latents = torch.lerp( + # noise, media_item_latents, strength + # ) + + # # Patchify the extra conditioning latents and calculate their pixel coordinates + # media_item_latents, latent_coords = self.patchifier.patchify( + # latents=media_item_latents + # ) + # pixel_coords = latent_to_pixel_coords( + # latent_coords, + # self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, + # ) + + # # Update the frame numbers to match the target frame number + # pixel_coords[:, 0] += media_frame_number + # extra_conditioning_num_latents += media_item_latents.shape[1] + + # conditioning_mask = torch.full( + # media_item_latents.shape[:2], + # strength, + # dtype=torch.float32, + # device=init_latents.device, + # ) + + # extra_conditioning_latents.append(media_item_latents) + # extra_conditioning_pixel_coords.append(pixel_coords) + # extra_conditioning_mask.append(conditioning_mask) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify( + latents=init_latents + ) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True + + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + # init_conditioning_mask, _ = self.patchifier.patchify( + # latents=init_conditioning_mask.unsqueeze(1) + # ) + # init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + # if extra_conditioning_latents: + # # Stack the extra conditioning latents, pixel coordinates and mask + # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + # init_pixel_coords = torch.cat( + # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + # ) + # init_conditioning_mask = torch.cat( + # [*extra_conditioning_mask, init_conditioning_mask], dim=1 + # ) + + # if self.transformer.use_tpu_flash_attention: + # # When flash attention is used, keep the original number of tokens by removing + # # tokens from the end. 
+ # init_latents = init_latents[:, :-extra_conditioning_num_latents] + # init_pixel_coords = init_pixel_coords[ + # :, :, :-extra_conditioning_num_latents + # ] + # init_conditioning_mask = init_conditioning_mask[ + # :, :-extra_conditioning_num_latents + # ] + + # return ( + # init_latents, + # init_pixel_coords, + # init_conditioning_mask, + # extra_conditioning_num_latents, + # ) + + # change the paramters of these, currently pass in dummy inputs + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + vae_per_channel_normalize = kwargs.get( + "vae_per_channel_normalize", True) + import pdb; pdb.set_trace() + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join( + base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) + + + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [ + i for i, val in enumerate(guidance_timesteps) if val <= timestep + ] + guidance_mapping.append( + indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) + ) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * 
len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask( + batch_size, num_conds, num_conds - 1, skip_blocks + ) + for skip_blocks in skip_block_list + ] + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate( + [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension + ) + prompt_attention_mask_batch = jnp.concatenate( + [negative_prompt_attention_mask, prompt_attention_mask], axis=0 + ) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = ( + self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + ) + + extra_step_kwargs = prepare_extra_step_kwargs( + generator=jax.random.PRNGKey(0)) + + pixel_coords = 
torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + # if not isinstance(guidance_scale, List): + # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) + # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + # latents = jax.device_put(example_inputs["latents"], data_sharding) + # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + noise_cond = jnp.ones( # initialize first round with this! + (1, 1) + ) + + # # noise_cond = None + # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" + # tensor_dict = torch.load(saved_tensor_path) + + # for key, value in tensor_dict.items(): + # if value is not None: + # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) + # example_inputs = tensor_dict + # latents = jax.device_put(example_inputs["latent_model_input"]) + # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) + # fractional_coords = jax.device_put(example_inputs["indices_grid"]) + # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) + # segment_ids = None + # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) + + # #only run this for the first time! 
+ # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) + # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken + # scheduler_state = self.scheduler_state + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = jax.jit( + # functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # ), + # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct + # out_shardings=None, + # ) + segment_ids = None + # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here + # p_run_inference = functools.partial( + # run_inference, + # transformer=self.transformer, + # config=self.config, + # mesh=self.mesh, + # fractional_cords=fractional_coords, + # prompt_embeds = prompt_embeds, + # segment_ids=segment_ids, + # encoder_attention_segment_ids=encoder_attention_segment_ids, + # num_inference_steps=num_inference_steps, + # scheduler=self.scheduler, + # # guidance_scale=guidance_scale + # ) + p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array( + fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling = do_rescaling, + rescaling_scale = rescaling_scale, + batch_size=batch_size, + skip_layer_masks = skip_layer_masks, + skip_layer_strategy = skip_layer_strategy, + cfg_star_rescale = cfg_star_rescale + ) + + with self.mesh: + latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( + # add scheduler state back in + torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) + latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] + // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [ + decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor( + decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to( + latents.device + )[:, None, None, 
None, None] + latents = ( + latents * (1 - decode_noise_scale) + + noise * decode_noise_scale + ) + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs.get( + "vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + image = self.image_processor.postprocess( + image, output_type=output_type) # shape mismatch here + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + # save states here + + +def transformer_forward_pass( # need to jit this? wan didnt + latents, + state, + noise_cond, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask, + skip_layer_strategy, +): + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy + ) # need .param here? + return noise_pred, state + + +def run_inference( + transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale +): + # do_classifier_free_guidance = guidance_scale > 1.0 + # for step in range(num_inference_steps): + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = ( + jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + ) + # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) + # timestep = jnp.broadcast_to( + # t, (latent_model_input.shape[0],) + # ).reshape(-1, 1) + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + # Determine the correct dtype based on the device (similar to PyTorch) + is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. 
+ if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension (compatible with ONNX/Core ML) + current_timestep = jnp.broadcast_to( + current_timestep, (latent_model_input.shape[0],1) + ) + + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( + skip_layer_masks[i] + if skip_layer_masks is not None + else None + ), skip_layer_strategy = skip_layer_strategy) + # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) + + # # latents = self.denoising + # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated + if do_spatio_temporal_guidance: + # JAX uses jnp.split for splitting along an axis. 
+ # Equivalent to .chunk() for equal-sized chunks + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = jnp.sum( #(1, 1) + positive_flat * negative_flat, axis=1, keepdims=True + ) + squared_norm = ( #(1, 1) + jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + ) + alpha = dot_product / squared_norm #might need to reshape this (1, 1) + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) + noise_pred = noise_pred_uncond + guidance_scale[i] * ( + noise_pred_text - noise_pred_uncond + ) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] # JAX slicing is similar + latents, scheduler_state = scheduler.step( + scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + +def adain_filter_latent( + latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 +): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. 
+ Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean( + reference_latents[i, c], dim=None + ) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + +class LTXMultiScalePipeline: + def _upsample_latents( + self, latest_upsampler: LatentUpsampler, latents: torch.Tensor + ): + assert latents.device == latest_upsampler.device + + latents = un_normalize_latents( + latents, self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler(latents) + upsampled_latents = normalize_latents( + upsampled_latents, self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents + + def __init__( + self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + ): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + self.latent_upsampler = latent_upsampler + + def __call__( + self, + height, + width, + num_frames, + is_video, + output_type, + generator, + config + ) -> Any: + + original_output_type = output_type + original_width = width + original_height = height + x_width = int(width * config.downscale_factor) + downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) + x_height = int(height * config.downscale_factor) + downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) + #use original height and width here to see + output_type = 'latent' + result = self.video_pipeline(height=original_height, width=original_width, num_frames=num_frames, + is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], + num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) + latents = result + upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents + ) + + + + latents = upsampled_latents + output_type = original_output_type + width = downscaled_width * 2 + height = downscaled_height * 2 + + result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, + is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], + num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, 
skip_block_list=config.second_pass["skip_block_list"]) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(original_height, original_width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result \ No newline at end of file diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py new file mode 100644 index 000000000..1624b81c4 --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -0,0 +1,357 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from pathlib import Path +import os +from safetensors import safe_open + +import flax +import jax +import jax.numpy as jnp +import json +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.utils import is_scipy_available +from maxdiffusion.schedulers.scheduling_utils_flax import ( + CommonSchedulerState, + FlaxSchedulerMixin, + FlaxSchedulerOutput, +) + +def linear_quadratic_schedule_jax(num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) + +def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + +def _prod_jax(iterable): + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + +def get_normal_shift_jax( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b +def append_dims_jax(x: jnp.ndarray, target_dims: 
int) -> jnp.ndarray: + """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" + return x[(...,) + (None,) * (target_dims - x.ndim)] + + + +def strech_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray: + if shifts.size == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + one_minus_z = 1.0 - shifts + # Using shifts[-1] for the last element + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched_shifts = 1.0 - (one_minus_z / scale_factor) + + return stretched_shifts + +def sd3_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + target_shift_terminal: Optional[float] = None, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + + shift = get_normal_shift_jax(int(m)) + time_shifts = time_shift_jax(shift, 1.0, timesteps) + + if target_shift_terminal is not None: + time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) + return time_shifts + + +def simple_diffusion_resolution_dependent_timestep_shift_jax( + samples_shape: Tuple[int, ...], + timesteps: jnp.ndarray, + n: int = 32 * 32, +) -> jnp.ndarray: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + # Ensure m and n are float32 for calculations + m_f = jnp.array(m, dtype=jnp.float32) + n_f = jnp.array(n, dtype=jnp.float32) + + snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability + shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability + shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) + + return shifted_timesteps + +@flax.struct.dataclass +class RectifiedFlowSchedulerState: + """ + Data class to hold the mutable state of the RectifiedFlowScheduler. + """ + + common: CommonSchedulerState + init_noise_sigma: float + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + + + + @classmethod + def create( #need to change this! 
+ cls, + common_state: CommonSchedulerState, + init_noise_sigma: float + ): + return cls( + common = common_state, + init_noise_sigma = init_noise_sigma, + num_inference_steps = None, + timesteps = None, + sigmas = None, + ) + + +@dataclass +class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): + state: RectifiedFlowSchedulerState + + +class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + + dtype: jnp.dtype + order = 1 + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None, + beta_schedule: str = "linear", + rescale_zero_terminal_snr: bool = False, + beta_start: float = 0.0001, + beta_end: float = 0.02, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + sampler: Optional[str] = "Uniform", + shift: Optional[float] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + init_noise_sigma = 1.0 + return RectifiedFlowSchedulerState.create(common_state = common, init_noise_sigma=init_noise_sigma) + + + def get_initial_timesteps_jax( + self, num_timesteps: int, shift: Optional[float] = None + ) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler." 
+ return time_shift_jax( + shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + ).astype(self.dtype) + else: + # This should be caught by __init__ but for safety + raise ValueError(f"Sampler {self.config.sampler} is not supported.") + + def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray: + if self.config.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift_jax( + samples_shape, timesteps, self.config.target_shift_terminal + ) + elif self.config.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift_jax( + samples_shape, timesteps, self.config.base_resolution + ) + return timesteps + + def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + config = None + if pretrained_model_path.is_file(): + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + configs = json.loads(metadata['config']) + config = configs["scheduler"] + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = ( + pretrained_model_path / "scheduler" / "scheduler_config.json" + ) + + if not diffusers_noise_scheduler_config_path.is_file(): + raise FileNotFoundError( + f"Scheduler config not found at {diffusers_noise_scheduler_config_path}" + ) + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + config = scheduler_config + return FlaxRectifiedFlowMultistepScheduler.from_config(config) + + def set_timesteps( + self, + state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[Tuple[int, ...]] = None, + timesteps: Optional[jnp.ndarray] = None, + device: Optional[str] = None, + ) -> RectifiedFlowSchedulerState: + if timesteps is not None and num_inference_steps is not None: + raise ValueError( + "You cannot provide both `timesteps` and `num_inference_steps`." + ) + + # Determine the number of inference steps if not provided + if num_inference_steps is None and timesteps is None: + raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") + + if timesteps is None: + num_inference_steps = jnp.minimum( + self.config.num_train_timesteps, num_inference_steps + ) + timesteps = self.get_initial_timesteps_jax( + num_inference_steps, shift=self.config.shift + ).astype(self.dtype) + + # Apply shifting if samples_shape is provided and shifting is configured + if samples_shape is not None: + timesteps = self.shift_timesteps_jax(samples_shape, timesteps) + else: + timesteps = jnp.asarray(timesteps, dtype=self.dtype) + num_inference_steps = len(timesteps) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + sigmas=timesteps, # sigmas are the same as timesteps in RF + ) + + def scale_model_input( + self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + # Rectified Flow scheduler typically doesn't scale model input, returns as is. + return sample + + + def step( + self, + state: RectifiedFlowSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, # Can be global or per-token, but for RF it's typically global. 
+ sample: jnp.ndarray, + return_dict: bool = True, + stochastic_sampling: bool = False, + generator: Optional[jax.random.PRNGKey] = None, + ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + t_eps = 1e-6 # Small epsilon for numerical issues + + timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) + + + if timestep.ndim == 0: + idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side='right') + current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] + lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), + timesteps_padded[current_t_idx + 1], + 0.0) + dt = timestep - lower_timestep + else: + current_t_indices = jnp.searchsorted(state.timesteps, timestep, side='right') # timesteps is decreasing + current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search + lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) + lower_timestep = timesteps_padded[lower_timestep_indices] + dt = timestep - lower_timestep + dt = append_dims_jax(dt, sample.ndim) + + + # Compute previous sample + if stochastic_sampling: + if generator is None: + raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") + broadcastable_timestep = append_dims_jax(timestep, sample.ndim) + + x0 = sample - broadcastable_timestep * model_output + next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep + + noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) + prev_sample = self.add_noise(state.common, x0, noise, next_timestep) + else: + prev_sample = sample - dt * model_output + + + if not return_dict: + return (prev_sample, state) + + return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) From a272d08f06ffb602c11572fbe4d6ab468aeb3e28 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 10 Jul 2025 22:57:03 +0000 Subject: [PATCH 26/34] load transformer error --- src/maxdiffusion/models/ltx_video/transformers/transformer3d.py | 1 + src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 7d068b0f6..12f035031 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -154,6 +154,7 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) def init_weights(self, key, in_channels, caption_channels, eval_only=True): + import pdb; pdb.set_trace() example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index d0a4ea6da..422005b19 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -240,7 +240,7 @@ def load_transformer(cls, config): tx=None, config=config, mesh=mesh, - weights_init_fn=weights_init_fn, + weights_init_fn=None, checkpoint_manager=checkpoint_manager, 
checkpoint_item=" ", model_params=None, From 4bcffd11b5b5a2dffd2526080ccf092e5ad7eb4a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:01:04 +0000 Subject: [PATCH 27/34] later --- .../models/ltx_video/transformers/transformer3d.py | 10 ++++------ .../models/ltx_video/xora_v1.2-13B-balanced-128.json | 2 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 9 +++------ 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 12f035031..e3110a82b 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -153,8 +153,7 @@ def scale_shift_table_init(key): weight_dtype=self.weight_dtype, matmul_precision=self.matmul_precision, ) - def init_weights(self, key, in_channels, caption_channels, eval_only=True): - import pdb; pdb.set_trace() + def init_weights(self, in_channels, key, caption_channels, eval_only=True): example_inputs = {} batch_size, num_tokens = 4, 256 input_shapes = { @@ -169,16 +168,15 @@ def init_weights(self, key, in_channels, caption_channels, eval_only=True): example_inputs[name] = jnp.ones( shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool ) - + if eval_only: return jax.eval_shape( self.init, - key, ##need to change? + key, **example_inputs, )["params"] else: - return self.init(key, **example_inputs)['params'] - + return self.init(key, **example_inputs)["params"] def create_skip_layer_mask( self, batch_size: int, diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index c5b3c0ef9..bce38fb20 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,5 +1,5 @@ { - "ckpt_path": "", + "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 422005b19..5443a539d 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -226,12 +226,9 @@ def load_transformer(cls, config): **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) weights_init_fn = functools.partial( - transformer.init_weights, - jax.random.PRNGKey(42), - in_channels, - model_config['caption_channels'], - eval_only=True + transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) @@ -240,7 +237,7 @@ def load_transformer(cls, config): tx=None, config=config, mesh=mesh, - weights_init_fn=None, + weights_init_fn=weights_init_fn, checkpoint_manager=checkpoint_manager, checkpoint_item=" ", model_params=None, From f5afa917d55b1095eeef8a0e40252ed8e3c031a7 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:09:49 +0000 Subject: [PATCH 28/34] changed repeatable layer --- .../models/ltx_video/repeatable_layer.py | 161 +++++++++--------- .../pipelines/ltx_video/ltx_video_pipeline.py | 2 +- 2 
files changed, 83 insertions(+), 80 deletions(-) diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 7e9cc80c4..5e2ceb789 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,118 +1,121 @@ -# Copyright 2025 Lightricks Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This implementation is based on the Torch version available at: -# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional import jax from flax import linen as nn +import jax.numpy as jnp from flax.linen import partitioning class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - @nn.compact - def __call__(self, *args) -> Tuple[jax.Array, None]: - """ - jax carry-op format of block - assumes the input contains an input tensor to the block along with kwargs that might be send to the block - kwargs are assumed to have static role, while the input changes between cycles + @nn.compact + def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: + data_input, index_input = carry - Returns: - Tuple[jax.Array, None]: Output tensor from the block - """ - mod = self.module(*self.module_init_args, **self.module_init_kwargs) - output = mod(*args) - return output, None + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + # block_args are the static arguments passed to each individual block + output_data = mod(index_input, data_input, *block_args) # Pass block_args to the module + + next_index = index_input + 1 + new_carry = (output_data, next_index) + + + return new_carry, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - 
num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} - - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, # Separate params per timestep - split_rngs={"params": True}, - in_axes=(nn.broadcast,) * (len(args) - 1), - length=self.num_layers, - **scan_kwargs, - ) - wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) - x, _ = wrapped_function(*args) - return x + @nn.compact + def __call__(self, *args): # args is now the full input to RepeatableLayer + if not args: + raise ValueError("RepeatableLayer expects at least one argument for initial data input.") + + initial_data_input = args[0] # The first element is your main data input + static_block_args = args[1:] # Any subsequent elements are static args for each block + + initial_index = jnp.array(0, dtype=jnp.int32) + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + + # in_axes for the scanned function (RepeatableCarryBlock.__call__): + # 1. The 'carry' tuple ((0, 0)) + # 2. 
Then, nn.broadcast for each of the `static_block_args` + in_axes_for_scan = (nn.broadcast,) * (len(args)-1) + + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True}, + in_axes=in_axes_for_scan, + length=self.num_layers, + **scan_kwargs, + ) + + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + + # Call wrapped_function with the initial carry tuple and the static_block_args + (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) + + # Typically, you only want the final data output from the sequence of layers + return final_data \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 5443a539d..67ad161ce 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -228,7 +228,7 @@ def load_transformer(cls, config): weights_init_fn = functools.partial( transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) - + import pdb; pdb.set_trace() absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) From bb61ecb24287a2a6beae81634fa5a7422de53952 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 18:51:33 +0000 Subject: [PATCH 29/34] functional --- src/maxdiffusion/checkpointing/checkpointing_utils.py | 7 +++++-- src/maxdiffusion/max_utils.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..5072b4639 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..e13f31f94 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( From 7d4b2a9782ed4389341ce58e6d3abbcc02fa0083 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 22:02:55 +0000 Subject: [PATCH 30/34] moved upsampler --- 
.../pipelines/ltx_video/ltx_video_pipeline.py | 140 ++++++------------ 1 file changed, 45 insertions(+), 95 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 67ad161ce..35750ad5e 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -11,37 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import math import os -import random from jax import Array -from datetime import datetime -from pathlib import Path +from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from diffusers import AutoencoderKL from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor -# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -import yaml -from transformers import (CLIPTokenizer, FlaxCLIPTextModel, - T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) - - -import imageio -import json -import numpy as np -import torch -from safetensors import safe_open -from PIL import Image from transformers import ( - T5EncoderModel, - T5Tokenizer, + FlaxT5EncoderModel, + AutoTokenizer, AutoModelForCausalLM, AutoProcessor, - AutoTokenizer, -) + AutoTokenizer,) +import json +import numpy as np +import torch from huggingface_hub import hf_hub_download from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, @@ -55,27 +42,19 @@ normalize_latents, ) from diffusers.image_processor import VaeImageProcessor -from ltx_video.schedulers.rf import RectifiedFlowScheduler from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler -import ltx_video.pipelines.crf_compressor as crf_compressor from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt from math import e from types import NoneType from typing import Any, Dict import numpy as np -import inspect import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P from typing import Optional, Union, List -import torch -from maxdiffusion.checkpointing import checkpointing_utils -from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...pyconfig import HyperParameters -# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState from ...max_utils import ( create_device_mesh, @@ -83,50 +62,9 @@ get_memory_allocations ) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import os import json import functools import orbax.checkpoint as ocp -import pickle - - -class PickleCheckpointHandler(ocp.CheckpointHandler): - def save(self, directory: str, item, args=None): - os.makedirs(directory, exist_ok=True) - with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f: - pickle.dump(item, f) - - def restore(self, directory: str, args=None): - with 
open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f: - return pickle.load(f) - - def structure(self, directory: str): - return {} # not needed for simple pickle-based handling - - -def save_tensor_dict(tensor_dict, timestep): - base_dir = os.path.dirname(__file__) - local_path = os.path.join(base_dir, f"schedulerTest{timestep}") - - try: - torch.save(tensor_dict, local_path) - print(f"Dictionary of tensors saved to: {local_path}") - except Exception as e: - print(f"Error saving dictionary: {e}") - raise - - -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) - # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) - print("encoder_attention_segment_ids.shape: ", - encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) - def prepare_extra_step_kwargs(generator): extra_step_kwargs = {} @@ -817,7 +755,6 @@ def __call__( skip_initial_inference_steps: int = 0, skip_final_inference_steps: int = 0, cfg_star_rescale: bool = False, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, **kwargs, ): @@ -1076,7 +1013,6 @@ def __call__( rescaling_scale = rescaling_scale, batch_size=batch_size, skip_layer_masks = skip_layer_masks, - skip_layer_strategy = skip_layer_strategy, cfg_star_rescale = cfg_star_rescale ) @@ -1149,7 +1085,6 @@ def transformer_forward_pass( # need to jit this? wan didnt segment_ids, encoder_attention_segment_ids, skip_layer_mask, - skip_layer_strategy, ): noise_pred = transformer.apply( {"params": state.params}, @@ -1160,13 +1095,12 @@ def transformer_forward_pass( # need to jit this? wan didnt segment_ids=segment_ids, encoder_attention_segment_ids=encoder_attention_segment_ids, skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy - ) # need .param here? 
+ ) return noise_pred, state def run_inference( - transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale + transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale ): # do_classifier_free_guidance = guidance_scale > 1.0 # for step in range(num_inference_steps): @@ -1206,7 +1140,7 @@ def run_inference( skip_layer_masks[i] if skip_layer_masks is not None else None - ), skip_layer_strategy = skip_layer_strategy) + )) # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) # # latents = self.denoising @@ -1294,6 +1228,31 @@ def adain_filter_latent( return result class LTXMultiScalePipeline: + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile( + spatial_upscaler_model_name_or_path + ): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir= "/mnt/disks/diffusionproj", + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + + def _upsample_latents( self, latest_upsampler: LatentUpsampler, latents: torch.Tensor ): @@ -1309,37 +1268,29 @@ def _upsample_latents( return upsampled_latents def __init__( - self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + self, video_pipeline: LTXVideoPipeline ): self.video_pipeline = video_pipeline self.vae = video_pipeline.vae - self.latent_upsampler = latent_upsampler - + def __call__( self, height, width, num_frames, - is_video, output_type, generator, config ) -> Any: + latent_upsampler = self.load_latent_upsampler(config) original_output_type = output_type - original_width = width - original_height = height - x_width = int(width * config.downscale_factor) - downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) - x_height = int(height * config.downscale_factor) - downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) - #use original height and width here to see output_type = 'latent' - result = self.video_pipeline(height=original_height, 
width=original_width, num_frames=num_frames, + result = self.video_pipeline(height=height, width=width, num_frames=num_frames, is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], - num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"]) + num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_block_list=config.first_pass["skip_block_list"]) latents = result - upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = self._upsample_latents(latent_upsampler, latents) upsampled_latents = adain_filter_latent( latents=upsampled_latents, reference_latents=latents ) @@ -1348,12 +1299,11 @@ def __call__( latents = upsampled_latents output_type = original_output_type - width = downscaled_width * 2 - height = downscaled_height * 2 + - result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames, + result = self.video_pipeline(height=height*2, width=width*2, num_frames=num_frames, is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], - num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.second_pass["skip_block_list"]) + num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_block_list=config.second_pass["skip_block_list"]) if original_output_type != "latent": num_frames = result.shape[2] @@ -1361,7 +1311,7 @@ def __call__( videos = F.interpolate( videos, - size=(original_height, original_width), + size=(height, width), mode="bilinear", align_corners=False, ) From 972e316601fc896abfea9b9f81d49ce77ee037e8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Fri, 11 Jul 2025 23:19:45 +0000 Subject: [PATCH 31/34] initial cleaning --- src/maxdiffusion/configs/ltx_video.yml | 10 +- src/maxdiffusion/generate_ltx_video.py | 50 +- .../models/ltx_video/repeatable_layer.py | 15 +- .../pipelines/ltx_video/ltx_video_pipeline.py | 485 +++++++----------- 4 files changed, 191 insertions(+), 369 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index cba635a1a..e54216c72 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -32,18 +32,10 @@ flow_shift: 5.0 downscale_factor: 0.6666666 spatial_upscaler_model_path: 
"ltxv-spatial-upscaler-0.9.7.safetensors" prompt_enhancement_words_threshold: 120 -# guidance_scale: [1, 1, 6, 8, 6, 1, 1] #4.5 -# stg_scale: [0, 0, 4, 4, 4, 2, 1] #1.0 -# rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] #0.7 -# num_inference_steps: 30 -# skip_final_inference_steps: 3 -# skip_initial_inference_steps: 0 -# guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] -# skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] stg_mode: "attention_values" decode_timestep: 0.05 decode_noise_scale: 0.025 -# cfg_star_rescale: True +models_dir: "/mnt/disks/diffusionproj" #where safetensor file is first_pass: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 90c82747c..9ad816564 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,16 +4,12 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -import jax.numpy as jnp -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from huggingface_hub import hf_hub_download import imageio from datetime import datetime -from maxdiffusion.utils import export_to_video import os -import json import torch from pathlib import Path @@ -96,52 +92,12 @@ def run(config): num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 padding = calculate_padding( config.height, config.width, height_padded, width_padded) - # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold - # prompt_word_count = len(config.prompt.split()) - # enhance_prompt = ( - # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold - # ) - seed = 10 # change this, generator in pytorch, used in prepare_latents + seed = 10 generator = torch.Generator().manual_seed(seed) pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) - if config.pipeline_type == "multi-scale": #move this to pipeline file?? 
- spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile( - spatial_upscaler_model_name_or_path - ): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir= "/mnt/disks/diffusionproj", - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = create_latent_upsampler( - spatial_upscaler_model_path, "cpu" #device set to cpu for now - ) - pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) - stg_mode = config.stg_mode - if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": - skip_layer_strategy = SkipLayerStrategy.AttentionValues - elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": - skip_layer_strategy = SkipLayerStrategy.AttentionSkip - elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": - skip_layer_strategy = SkipLayerStrategy.Residual - elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": - skip_layer_strategy = SkipLayerStrategy.TransformerBlock - else: - raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") - # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, - # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, - # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images - images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config) + pipeline = LTXMultiScalePipeline(pipeline) + images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config) (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom pad_right = -pad_right diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 5e2ceb789..06b367ce4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -25,8 +25,7 @@ def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tup mod = self.module(*self.module_init_args, **self.module_init_kwargs) - # block_args are the static arguments passed to each individual block - output_data = mod(index_input, data_input, *block_args) # Pass block_args to the module + output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers next_index = index_input + 1 new_carry = (output_data, next_index) @@ -76,14 +75,14 @@ class RepeatableLayer(nn.Module): """ @nn.compact - def __call__(self, *args): # args is now the full input to RepeatableLayer + def __call__(self, *args): if not args: raise ValueError("RepeatableLayer 
expects at least one argument for initial data input.") - initial_data_input = args[0] # The first element is your main data input - static_block_args = args[1:] # Any subsequent elements are static args for each block + initial_data_input = args[0] + static_block_args = args[1:] - initial_index = jnp.array(0, dtype=jnp.int32) + initial_index = jnp.array(0, dtype=jnp.int32) #index of current transformer block scan_kwargs = {} if self.pspec_name is not None: @@ -92,9 +91,6 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer initializing = self.is_mutable_collection("params") params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) - # in_axes for the scanned function (RepeatableCarryBlock.__call__): - # 1. The 'carry' tuple ((0, 0)) - # 2. Then, nn.broadcast for each of the `static_block_args` in_axes_for_scan = (nn.broadcast,) * (len(args)-1) scan_fn = nn.scan( @@ -117,5 +113,4 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer # Call wrapped_function with the initial carry tuple and the static_block_args (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) - # Typically, you only want the final data output from the sequence of layers return final_data \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 35750ad5e..5584fcc0f 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -117,21 +117,6 @@ def __init__( @classmethod def load_scheduler(cls, ckpt_path, config): - # scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - # "Wan-AI/Wan2.1-T2V-14B-Diffusers", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # scheduler, scheduler_state = FlaxRectifiedFlowMultistepScheduler.from_pretrained( - # "Lightricks/LTX-Video", - # subfolder="scheduler", - # flow_shift=5.0 # 5.0 for 720p, 3.0 for 480p - # ) - # import pdb; pdb.set_trace() - # scheduler = FlaxRectifiedFlowMultistepScheduler( - # sampler="LinearQuadratic" - # ) - # scheduler_state = scheduler.create_state() if config.sampler == "from_checkpoint" or not config.sampler: scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) else: @@ -159,14 +144,14 @@ def load_transformer(cls, config): for name in ignored_keys: if name in model_config: del model_config[name] + transformer = Transformer3DModel( - # change this sharding back **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) weights_init_fn = functools.partial( transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True ) - import pdb; pdb.set_trace() + ##load in jax weights checkpoint absolute_ckpt_path = os.path.abspath(relative_ckpt_path) checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) @@ -194,17 +179,11 @@ def load_vae(cls, ckpt_path): @classmethod def load_text_encoder(cls, ckpt_path): - # text_encoder = T5EncoderModel.from_pretrained( - # ckpt_path, subfolder="text_encoder" - # ) t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) return t5_encoder @classmethod def load_tokenizer(cls, config, ckpt_path): - # tokenizer = T5Tokenizer.from_pretrained( - # ckpt_path, subfolder="tokenizer" - # ) t5_tokenizer = AutoTokenizer.from_pretrained( 
ckpt_path, max_length=config.max_sequence_length, use_fast=True ) @@ -235,7 +214,7 @@ def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): config) # load from pytorch version - models_dir = "/mnt/disks/diffusionproj" #edit this! + models_dir = config.models_dir ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" if not os.path.isfile(ltxv_model_name_or_path): ltxv_model_path = hf_hub_download( @@ -292,6 +271,7 @@ def process(text: str): return text return [process(t) for t in text] + def denoising_step( scheduler, latents: Array, @@ -303,14 +283,6 @@ def denoising_step( t_eps: float = 1e-6, stochastic_sampling: bool = False, ) -> Array: - """ - Perform the denoising step for the required tokens, based on the current timestep and - conditioning mask: - Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) - and will start to be denoised when the current timestep is equal or lower than their - conditioning timestep. - (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) - """ # Denoise the latents using the scheduler denoised_latents = scheduler.step( noise_pred, @@ -343,7 +315,7 @@ def retrieve_timesteps( #currently doesn't support custom timesteps skip_initial_inference_steps < 0 or skip_final_inference_steps < 0 or skip_initial_inference_steps + skip_final_inference_steps - >= num_inference_steps # Use the original num_inference_steps here for the check + >= num_inference_steps ): raise ValueError( "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" @@ -378,13 +350,13 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] max_length = ( - text_encoder_max_tokens # TPU supports only lengths multiple of 128 + text_encoder_max_tokens ) if prompt_embeds is None: assert ( self.text_encoder is not None ), "You should provide either prompt_embeds or self.text_encoder should not be None," - # text_enc_device = next(self.text_encoder.parameters()) + prompt = self._text_preprocessing(prompt) text_inputs = self.tokenizer( prompt, @@ -419,25 +391,15 @@ def encode_prompt( else: dtype = None bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - print(isinstance(prompt_embeds, Array)) - # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - # prompt_embeds = prompt_embeds.view( - # bs_embed * num_images_per_prompt, seq_len, -1 - # ) prompt_embeds = jnp.reshape( prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) prompt_attention_mask = jnp.tile( prompt_attention_mask, (1, num_images_per_prompt)) - # prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) - # prompt_attention_mask = prompt_attention_mask.view( - # bs_embed * num_images_per_prompt, -1 - # ) prompt_attention_mask = jnp.reshape( prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - # get unconditional embeddings for classifier free guidance hasn't changed yet + # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens = self._text_preprocessing(negative_prompt) uncond_tokens = uncond_tokens * batch_size @@ -460,7 +422,6 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds[0] if 
do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = jnp.tile(negative_prompt_embeds, @@ -481,13 +442,13 @@ def encode_prompt( negative_prompt_attention_mask = None return ( - prompt_embeds,#(1, 256, 4096) - prompt_attention_mask, #1, 256 + prompt_embeds, + prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, ) - def prepare_latents( + def prepare_latents( ## this is in pytorch self, latents: torch.Tensor | None, media_items: torch.Tensor | None, @@ -556,123 +517,123 @@ def prepare_conditioning( assert isinstance(self.vae, CausalVideoAutoencoder) - # if conditioning_items: - # batch_size, _, num_latent_frames = init_latents.shape[:3] - - # init_conditioning_mask = torch.zeros( - # init_latents[:, 0, :, :, :].shape, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents = [] - # extra_conditioning_pixel_coords = [] - # extra_conditioning_mask = [] - # extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) - - # # Process each conditioning item - # for conditioning_item in conditioning_items: - # conditioning_item = self._resize_conditioning_item( - # conditioning_item, height, width - # ) - # media_item = conditioning_item.media_item - # media_frame_number = conditioning_item.media_frame_number - # strength = conditioning_item.conditioning_strength - # assert media_item.ndim == 5 # (b, c, f, h, w) - # b, c, n_frames, h, w = media_item.shape - # assert ( - # height == h and width == w - # ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - # assert n_frames % 8 == 1 - # assert ( - # media_frame_number >= 0 - # and media_frame_number + n_frames <= num_frames - # ) - - # # Encode the provided conditioning media item - # media_item_latents = vae_encode( - # media_item.to(dtype=self.vae.dtype, device=self.vae.device), - # self.vae, - # vae_per_channel_normalize=vae_per_channel_normalize, - # ).to(dtype=init_latents.dtype) - - # # Handle the different conditioning cases - # if media_frame_number == 0: - # # Get the target spatial position of the latent conditioning item - # media_item_latents, l_x, l_y = self._get_latent_spatial_position( - # media_item_latents, - # conditioning_item, - # height, - # width, - # strip_latent_border=True, - # ) - # b, c_l, f_l, h_l, w_l = media_item_latents.shape - - # # First frame or sequence - just update the initial noise latents and the mask - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( - # torch.lerp( - # init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], - # media_item_latents, - # strength, - # ) - # ) - # init_conditioning_mask[ - # :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l - # ] = strength - # else: - # # Non-first frame or sequence - # if n_frames > 1: - # # Handle non-first sequence. - # # Encoded latents are either fully consumed, or the prefix is handled separately below. 
- # ( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # ) = self._handle_non_first_conditioning_sequence( - # init_latents, - # init_conditioning_mask, - # media_item_latents, - # media_frame_number, - # strength, - # ) - - # # Single frame or sequence-prefix latents - # if media_item_latents is not None: - # noise = randn_tensor( - # media_item_latents.shape, - # generator=generator, - # device=media_item_latents.device, - # dtype=media_item_latents.dtype, - # ) - - # media_item_latents = torch.lerp( - # noise, media_item_latents, strength - # ) - - # # Patchify the extra conditioning latents and calculate their pixel coordinates - # media_item_latents, latent_coords = self.patchifier.patchify( - # latents=media_item_latents - # ) - # pixel_coords = latent_to_pixel_coords( - # latent_coords, - # self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, - # ) - - # # Update the frame numbers to match the target frame number - # pixel_coords[:, 0] += media_frame_number - # extra_conditioning_num_latents += media_item_latents.shape[1] - - # conditioning_mask = torch.full( - # media_item_latents.shape[:2], - # strength, - # dtype=torch.float32, - # device=init_latents.device, - # ) - - # extra_conditioning_latents.append(media_item_latents) - # extra_conditioning_pixel_coords.append(pixel_coords) - # extra_conditioning_mask.append(conditioning_mask) + if conditioning_items: + batch_size, _, num_latent_frames = init_latents.shape[:3] + + init_conditioning_mask = torch.zeros( + init_latents[:, 0, :, :, :].shape, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents = [] + extra_conditioning_pixel_coords = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # Process each conditioning item + for conditioning_item in conditioning_items: + conditioning_item = self._resize_conditioning_item( + conditioning_item, height, width + ) + media_item = conditioning_item.media_item + media_frame_number = conditioning_item.media_frame_number + strength = conditioning_item.conditioning_strength + assert media_item.ndim == 5 # (b, c, f, h, w) + b, c, n_frames, h, w = media_item.shape + assert ( + height == h and width == w + ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + assert n_frames % 8 == 1 + assert ( + media_frame_number >= 0 + and media_frame_number + n_frames <= num_frames + ) + + # Encode the provided conditioning media item + media_item_latents = vae_encode( + media_item.to(dtype=self.vae.dtype, device=self.vae.device), + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ).to(dtype=init_latents.dtype) + + # Handle the different conditioning cases + if media_frame_number == 0: + # Get the target spatial position of the latent conditioning item + media_item_latents, l_x, l_y = self._get_latent_spatial_position( + media_item_latents, + conditioning_item, + height, + width, + strip_latent_border=True, + ) + b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # First frame or sequence - just update the initial noise latents and the mask + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + torch.lerp( + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + media_item_latents, + strength, + ) + ) + init_conditioning_mask[ + :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + ] = strength + else: + # Non-first 
frame or sequence + if n_frames > 1: + # Handle non-first sequence. + # Encoded latents are either fully consumed, or the prefix is handled separately below. + ( + init_latents, + init_conditioning_mask, + media_item_latents, + ) = self._handle_non_first_conditioning_sequence( + init_latents, + init_conditioning_mask, + media_item_latents, + media_frame_number, + strength, + ) + + # Single frame or sequence-prefix latents + if media_item_latents is not None: + noise = randn_tensor( + media_item_latents.shape, + generator=generator, + device=media_item_latents.device, + dtype=media_item_latents.dtype, + ) + + media_item_latents = torch.lerp( + noise, media_item_latents, strength + ) + + # Patchify the extra conditioning latents and calculate their pixel coordinates + media_item_latents, latent_coords = self.patchifier.patchify( + latents=media_item_latents + ) + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.vae, + causal_fix=self.transformer.config.causal_temporal_positioning, + ) + + # Update the frame numbers to match the target frame number + pixel_coords[:, 0] += media_frame_number + extra_conditioning_num_latents += media_item_latents.shape[1] + + conditioning_mask = torch.full( + media_item_latents.shape[:2], + strength, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents.append(media_item_latents) + extra_conditioning_pixel_coords.append(pixel_coords) + extra_conditioning_mask.append(conditioning_mask) # Patchify the updated latents and calculate their pixel coordinates init_latents, init_latent_coords = self.patchifier.patchify( @@ -683,46 +644,45 @@ def prepare_conditioning( self.vae, # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now causal_fix=True - ) if not conditioning_items: return init_latents, init_pixel_coords, None, 0 - # init_conditioning_mask, _ = self.patchifier.patchify( - # latents=init_conditioning_mask.unsqueeze(1) - # ) - # init_conditioning_mask = init_conditioning_mask.squeeze(-1) - - # if extra_conditioning_latents: - # # Stack the extra conditioning latents, pixel coordinates and mask - # init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) - # init_pixel_coords = torch.cat( - # [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 - # ) - # init_conditioning_mask = torch.cat( - # [*extra_conditioning_mask, init_conditioning_mask], dim=1 - # ) - - # if self.transformer.use_tpu_flash_attention: - # # When flash attention is used, keep the original number of tokens by removing - # # tokens from the end. 
- # init_latents = init_latents[:, :-extra_conditioning_num_latents] - # init_pixel_coords = init_pixel_coords[ - # :, :, :-extra_conditioning_num_latents - # ] - # init_conditioning_mask = init_conditioning_mask[ - # :, :-extra_conditioning_num_latents - # ] - - # return ( - # init_latents, - # init_pixel_coords, - # init_conditioning_mask, - # extra_conditioning_num_latents, - # ) - - # change the paramters of these, currently pass in dummy inputs + init_conditioning_mask, _ = self.patchifier.patchify( + latents=init_conditioning_mask.unsqueeze(1) + ) + init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + if extra_conditioning_latents: + # Stack the extra conditioning latents, pixel coordinates and mask + init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + init_pixel_coords = torch.cat( + [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + ) + init_conditioning_mask = torch.cat( + [*extra_conditioning_mask, init_conditioning_mask], dim=1 + ) + + if self.transformer.use_tpu_flash_attention: + # When flash attention is used, keep the original number of tokens by removing + # tokens from the end. + init_latents = init_latents[:, :-extra_conditioning_num_latents] + init_pixel_coords = init_pixel_coords[ + :, :, :-extra_conditioning_num_latents + ] + init_conditioning_mask = init_conditioning_mask[ + :, :-extra_conditioning_num_latents + ] + + return ( + init_latents, + init_pixel_coords, + init_conditioning_mask, + extra_conditioning_num_latents, + ) + + def __call__( self, @@ -881,8 +841,7 @@ def __call__( prompt_attention_mask_batch = prompt_attention_mask if do_classifier_free_guidance: prompt_embeds_batch = jnp.concatenate( - [negative_prompt_embeds, prompt_embeds], axis=0 #check negative_prompt_embeds dimension - ) + [negative_prompt_embeds, prompt_embeds], axis=0) prompt_attention_mask_batch = jnp.concatenate( [negative_prompt_attention_mask, prompt_attention_mask], axis=0 ) @@ -898,7 +857,7 @@ def __call__( latents = self.prepare_latents( latents=latents, media_items=None, # set to None - timestep=scheduler_state.timesteps[0], # set to 1.0 for now TODO: fix this + timestep=scheduler_state.timesteps[0], latent_shape=latent_shape, dtype=None, device=None, @@ -925,73 +884,10 @@ def __call__( fractional_coords = pixel_coords.to(torch.float32) fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - # if not isinstance(guidance_scale, List): - # guidance_scale = [guidance_scale] * len(self.scheduler_state.timesteps) - # guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - # data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - # latents = jax.device_put(example_inputs["latents"], data_sharding) - # prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - # fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - # noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - # segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) noise_cond = jnp.ones( # initialize first round with this! 
(1, 1) ) - - # # noise_cond = None - # saved_tensor_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest1.0" - # tensor_dict = torch.load(saved_tensor_path) - - # for key, value in tensor_dict.items(): - # if value is not None: - # tensor_dict[key] = jnp.array(value.to(torch.float32).cpu().numpy()) - # example_inputs = tensor_dict - # latents = jax.device_put(example_inputs["latent_model_input"]) - # # prompt_embeds = jax.device_put(example_inputs["encoder_hidden_states"]) - # fractional_coords = jax.device_put(example_inputs["indices_grid"]) - # encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"]) - # segment_ids = None - # # validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids) - - # #only run this for the first time! - # scheduler_state = self.scheduler.set_timesteps(state=self.scheduler_state, shape=latents.shape, num_inference_steps=num_inference_steps) - # extra_step_kwargs = prepare_extra_step_kwargs(generator = jax.random.PRNGKey(0)) #check if this value needs to be changed, for unipc eta is not taken - # scheduler_state = self.scheduler_state - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = jax.jit( - # functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # ), - # in_shardings=(self.state_shardings, data_sharding, data_sharding, None), #not sure if this sharding is correct - # out_shardings=None, - # ) - segment_ids = None - # num_warmup_steps = max(len(self.scheduler_state.timesteps) - num_inference_steps * self.scheduler.order, 0) #no paramter order here - # p_run_inference = functools.partial( - # run_inference, - # transformer=self.transformer, - # config=self.config, - # mesh=self.mesh, - # fractional_cords=fractional_coords, - # prompt_embeds = prompt_embeds, - # segment_ids=segment_ids, - # encoder_attention_segment_ids=encoder_attention_segment_ids, - # num_inference_steps=num_inference_steps, - # scheduler=self.scheduler, - # # guidance_scale=guidance_scale - # ) + segment_ids = None #how is this created? p_run_inference = functools.partial( run_inference, transformer=self.transformer, @@ -1018,7 +914,6 @@ def __call__( with self.mesh: latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( - # add scheduler state back in torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) latents = torch.from_numpy(np.array(latents)) latents = latents[:, num_cond_latents:] @@ -1061,7 +956,7 @@ def __call__( timestep=decode_timestep, ) image = self.image_processor.postprocess( - image, output_type=output_type) # shape mismatch here + image, output_type=output_type) else: image = latents @@ -1072,7 +967,7 @@ def __call__( return (image,) return image - # save states here + def transformer_forward_pass( # need to jit this? wan didnt @@ -1102,21 +997,13 @@ def transformer_forward_pass( # need to jit this? 
wan didnt def run_inference( transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale ): - # do_classifier_free_guidance = guidance_scale > 1.0 - # for step in range(num_inference_steps): for i, t in enumerate(scheduler_state.timesteps): current_timestep = t latent_model_input = ( jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents ) - # t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - # # timestep = jnp.broadcast_to(t, timestep.shape) # (4, 256) - # timestep = jnp.broadcast_to( - # t, (latent_model_input.shape[0],) - # ).reshape(-1, 1) if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - # Determine the correct dtype based on the device (similar to PyTorch) - is_mps = False # MPS is not a JAX concept, remove it. JAX handles devices automatically. + is_mps = False if isinstance(current_timestep, float): dtype = jnp.float32 else: @@ -1129,49 +1016,43 @@ def run_inference( elif current_timestep.ndim == 0: current_timestep = jnp.expand_dims(current_timestep, axis=0) - # Broadcast to batch dimension (compatible with ONNX/Core ML) + # Broadcast to batch dimension current_timestep = jnp.broadcast_to( current_timestep, (latent_model_input.shape[0],1) ) - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( - skip_layer_masks[i] - if skip_layer_masks is not None - else None - )) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( + skip_layer_masks[i] + if skip_layer_masks is not None + else None + )) # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) - # # latents = self.denoising - # latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - # noise_pred, transformer_state = transformer_forward_pass(latents, transformer_state, timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids) #need to check if transformer_state is successfully updated + if do_spatio_temporal_guidance: - # JAX uses jnp.split for splitting along an axis. 
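                # Illustrative aside: a standalone sketch of the CFG-star rescale applied a few
                # lines below. It projects the unconditional prediction onto the conditional one
                # (alpha = <text, uncond> / ||uncond||^2) before the usual classifier-free-guidance
                # mix. The array shapes and the scalar guidance_scale here are assumptions made for
                # the example only; the pipeline's own variables are used elsewhere in this function.
                #
                #   import jax.numpy as jnp
                #
                #   def cfg_star_combine(noise_pred_uncond, noise_pred_text, guidance_scale, batch_size=1):
                #       positive_flat = noise_pred_text.reshape(batch_size, -1)
                #       negative_flat = noise_pred_uncond.reshape(batch_size, -1)
                #       dot_product = jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True)
                #       squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8
                #       alpha = (dot_product / squared_norm).reshape(batch_size, 1, 1)
                #       rescaled_uncond = alpha * noise_pred_uncond
                #       return rescaled_uncond + guidance_scale * (noise_pred_text - rescaled_uncond)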
- # Equivalent to .chunk() for equal-sized chunks chunks = jnp.split(noise_pred, num_conds, axis=0) noise_pred_text = chunks[-2] noise_pred_text_perturb = chunks[-1] if do_classifier_free_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) noise_pred_uncond = chunks[0] noise_pred_text = chunks[1] if cfg_star_rescale: positive_flat = noise_pred_text.reshape(batch_size, -1) negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum( #(1, 1) + dot_product = jnp.sum( positive_flat * negative_flat, axis=1, keepdims=True ) - squared_norm = ( #(1, 1) + squared_norm = ( jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 ) - alpha = dot_product / squared_norm #might need to reshape this (1, 1) + alpha = dot_product / squared_norm alpha = alpha.reshape(batch_size, 1, 1) - noise_pred_uncond = alpha * noise_pred_uncond #error here (1, 3072, 128) + noise_pred_uncond = alpha * noise_pred_uncond noise_pred = noise_pred_uncond + guidance_scale[i] * ( noise_pred_text - noise_pred_uncond ) @@ -1191,7 +1072,7 @@ def run_inference( noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] # JAX slicing is similar + current_timestep = current_timestep[:1] latents, scheduler_state = scheduler.step( scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() @@ -1295,8 +1176,6 @@ def __call__( latents=upsampled_latents, reference_latents=latents ) - - latents = upsampled_latents output_type = original_output_type From c37547102158920a1c4197e381156c83c9b2afd8 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 18:47:16 +0000 Subject: [PATCH 32/34] multiscale pipeline --- .../checkpointing/checkpointing_utils.py | 2 +- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 202 +- src/maxdiffusion/max_utils.py | 2 +- .../ltx_video/autoencoders/causal_conv3d.py | 63 - .../autoencoders/causal_video_autoencoder.py | 1398 ----------- .../ltx_video/autoencoders/conv_nd_factory.py | 90 - .../ltx_video/autoencoders/dual_conv3d.py | 217 -- .../autoencoders/latent_upsampler.py | 203 -- .../ltx_video/autoencoders/pixel_norm.py | 12 - .../ltx_video/autoencoders/pixel_shuffle.py | 33 - .../models/ltx_video/autoencoders/vae.py | 380 --- .../ltx_video/autoencoders/vae_encode.py | 247 -- .../autoencoders/video_autoencoder.py | 1045 --------- .../{ => models}/autoencoders/__init__.py | 0 .../models/ltx_video/repeatable_layer.py | 130 +- .../ltx_video/transformers/attention.py | 1689 +++++++------- .../transformers/symmetric_patchifier.py | 84 - .../ltx_video/transformers/transformer3d.py | 580 +++-- .../utils/diffusers_config_mapping.py | 174 -- .../ltx_video/utils/prompt_enhance_utils.py | 226 -- .../ltx_video/utils/skip_layer_strategy.py | 8 - .../models/ltx_video/utils/torch_utils.py | 25 - .../pipelines/ltx_video/ltx_video_pipeline.py | 2062 ++++++++--------- .../schedulers/scheduling_rectified_flow.py | 422 ++-- 25 files changed, 2426 insertions(+), 6870 deletions(-) delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py delete mode 
100644 src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py delete mode 100644 src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py rename src/maxdiffusion/models/ltx_video/{ => models}/autoencoders/__init__.py (100%) delete mode 100644 src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py delete mode 100644 src/maxdiffusion/models/ltx_video/utils/torch_utils.py diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 5072b4639..b83e85a87 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -217,7 +217,7 @@ def load_state_if_possible( return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) else: item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index e54216c72..0fdbe7f9f 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -24,7 +24,7 @@ sampler: "from_checkpoint" # Generation parameters pipeline_type: multi-scale -prompt: "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage." +prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." 
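# Note: generate_ltx_video.py pads these values before inference. Height and width are
# rounded up to the next multiple of 32 and num_frames to the next 8*k + 1, so the
# defaults below become 512x512 and 89 frames; the output is cropped back to the
# requested size after decoding.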
height: 512 width: 512 num_frames: 88 #344 diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 9ad816564..f7d7e6d03 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -4,11 +4,8 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline from maxdiffusion import pyconfig -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler -from huggingface_hub import hf_hub_download import imageio from datetime import datetime - import os import torch from pathlib import Path @@ -18,52 +15,45 @@ def calculate_padding( source_height: int, source_width: int, target_height: int, target_width: int ) -> tuple[int, int, int, int]: - # Calculate total padding needed - pad_height = target_height - source_height - pad_width = target_width - source_width + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width - # Calculate padding for each side - pad_top = pad_height // 2 - pad_bottom = pad_height - pad_top # Handles odd padding - pad_left = pad_width // 2 - pad_right = pad_width - pad_left # Handles odd padding + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding - # Return padded tensor - # Padding format is (left, right, top, bottom) - padding = (pad_left, pad_right, pad_top, pad_bottom) - return padding + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: - # Remove non-letters and convert to lowercase - clean_text = "".join( - char.lower() for char in text if char.isalpha() or char.isspace() - ) + # Remove non-letters and convert to lowercase + clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace()) - # Split into words - words = clean_text.split() + # Split into words + words = clean_text.split() - # Build result string keeping track of length - result = [] - current_length = 0 + # Build result string keeping track of length + result = [] + current_length = 0 - for word in words: - # Add word length plus 1 for underscore (except for first word) - new_length = current_length + len(word) + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) - if new_length <= max_len: - result.append(word) - current_length += len(word) - else: - break + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break - return "-".join(result) + return "-".join(result) -def create_latent_upsampler(latent_upsampler_model_path: str, device: str): - latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) - latent_upsampler.to(device) - latent_upsampler.eval() - return latent_upsampler def get_unique_filename( base: str, @@ -75,78 +65,82 @@ def get_unique_filename( endswith=None, index_range=1000, ) -> Path: - base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" - for i in range(index_range): - filename = dir / \ - f"{base_filename}_{i}{endswith if endswith else ''}{ext}" - if not 
os.path.exists(filename): - return filename - raise FileExistsError( - f"Could not find a unique filename after {index_range} attempts." - ) + base_filename = ( + f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + ) + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.") def run(config): - height_padded = ((config.height - 1) // 32 + 1) * 32 - width_padded = ((config.width - 1) // 32 + 1) * 32 - num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 - padding = calculate_padding( - config.height, config.width, height_padded, width_padded) - - seed = 10 - generator = torch.Generator().manual_seed(seed) - pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) - pipeline = LTXMultiScalePipeline(pipeline) - images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config) - (pad_left, pad_right, pad_top, pad_bottom) = padding - pad_bottom = -pad_bottom - pad_right = -pad_right - if pad_bottom == 0: - pad_bottom = images.shape[3] - if pad_right == 0: - pad_right = images.shape[4] - images = images[:, :, :config.num_frames, - pad_top:pad_bottom, pad_left:pad_right] - output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") - output_dir.mkdir(parents=True, exist_ok=True) - for i in range(images.shape[0]): - # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C - video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() - # Unnormalizing images to [0, 255] range - video_np = (video_np * 255).astype(np.uint8) - fps = config.frame_rate - height, width = video_np.shape[1:3] - # In case a single image is generated - if video_np.shape[0] == 1: - output_filename = get_unique_filename( - f"image_output_{i}", - ".png", - prompt=config.prompt, - seed=seed, - resolution=(height, width, config.num_frames), - dir=output_dir, - ) - imageio.imwrite(output_filename, video_np[0]) - else: - output_filename = get_unique_filename( - f"video_output_{i}", - ".mp4", - prompt=config.prompt, - seed=seed, - resolution=(height, width, config.num_frames), - dir=output_dir, - ) - print(output_filename) - # Write video - with imageio.get_writer(output_filename, fps=fps) as video: - for frame in video_np: - video.append_data(frame) + height_padded = ((config.height - 1) // 32 + 1) * 32 + width_padded = ((config.width - 1) // 32 + 1) * 32 + num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 + padding = calculate_padding(config.height, config.width, height_padded, width_padded) + + seed = 10 + generator = torch.Generator().manual_seed(seed) + pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False) + pipeline = LTXMultiScalePipeline(pipeline) + images = pipeline( + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + output_type="pt", + generator=generator, + config=config, + ) + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right] + output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + 
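  # Illustrative aside: how the symmetric padding above is undone. calculate_padding
  # returns (left, right, top, bottom); for an assumed 500x500 request padded to 512x512
  # it returns (6, 6, 6, 6). The crop negates the right/bottom pads so
  # images[..., pad_top:pad_bottom, pad_left:pad_right] slices [6:-6, 6:-6] and recovers
  # the 500x500 frames; a zero pad is first swapped for the full extent, since a literal
  # 0 would make the slice 0:0 and empty.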
output_dir.mkdir(parents=True, exist_ok=True) + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = config.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=config.prompt, + seed=seed, + resolution=(height, width, config.num_frames), + dir=output_dir, + ) + print(output_filename) + # Write video + with imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index e13f31f94..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -612,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py deleted file mode 100644 index 98249c2f5..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Tuple, Union - -import torch -import torch.nn as nn - - -class CausalConv3d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size: int = 3, - stride: Union[int, Tuple[int]] = 1, - dilation: int = 1, - groups: int = 1, - spatial_padding_mode: str = "zeros", - **kwargs, - ): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - - kernel_size = (kernel_size, kernel_size, kernel_size) - self.time_kernel_size = kernel_size[0] - - dilation = (dilation, 1, 1) - - height_pad = kernel_size[1] // 2 - width_pad = kernel_size[2] // 2 - padding = (0, height_pad, width_pad) - - self.conv = nn.Conv3d( - in_channels, - out_channels, - kernel_size, - stride=stride, - dilation=dilation, - padding=padding, - padding_mode=spatial_padding_mode, - groups=groups, - ) - - def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x - - @property - def weight(self): - return self.conv.weight diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py deleted file 
mode 100644 index 1255b6d34..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/causal_video_autoencoder.py +++ /dev/null @@ -1,1398 +0,0 @@ -import json -import os -from functools import partial -from types import SimpleNamespace -from typing import Any, Mapping, Optional, Tuple, Union, List -from pathlib import Path - -import torch -import numpy as np -from einops import rearrange -from torch import nn -from diffusers.utils import logging -import torch.nn.functional as F -from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings -from safetensors import safe_open - - -from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm -from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND -from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper -from maxdiffusion.models.ltx_video.transformers.attention import Attention -from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( - diffusers_and_ours_config_mapping, - make_hashable_key, - VAE_KEYS_RENAME_DICT, -) - -PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class CausalVideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_name_or_path = Path(pretrained_model_name_or_path) - if ( - pretrained_model_name_or_path.is_dir() - and (pretrained_model_name_or_path / "autoencoder.pth").exists() - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - std_of_means = data_dict["std-of-means"] - mean_of_means = data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( - std_of_means - ) - state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( - mean_of_means - ) - - elif pretrained_model_name_or_path.is_dir(): - config_path = pretrained_model_name_or_path / "vae" / "config.json" - with open(config_path, "r") as f: - config = make_hashable_key(json.load(f)) - - assert config in diffusers_and_ours_config_mapping, ( - "Provided diffusers checkpoint config for VAE is not suppported. " - "We only support diffusers configs found in Lightricks/LTX-Video." 
- ) - - config = diffusers_and_ours_config_mapping[config] - - state_dict_path = ( - pretrained_model_name_or_path - / "vae" - / "diffusion_pytorch_model.safetensors" - ) - - state_dict = {} - with safe_open(state_dict_path, framework="pt", device="cpu") as f: - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - for key in list(state_dict.keys()): - new_key = key - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - - state_dict[new_key] = state_dict.pop(key) - - elif pretrained_model_name_or_path.is_file() and str( - pretrained_model_name_or_path - ).endswith(".safetensors"): - state_dict = {} - with safe_open( - pretrained_model_name_or_path, framework="pt", device="cpu" - ) as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - configs = json.loads(metadata["config"]) - config = configs["vae"] - - video_vae = cls.from_config(config) - if "torch_dtype" in kwargs: - video_vae.to(kwargs["torch_dtype"]) - video_vae.load_state_dict(state_dict) - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "CausalVideoAutoencoder" - ), "config must have _class_name=CausalVideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - normalize_latent_channels = config.get("normalize_latent_channels", False) - - if use_quant_conv and latent_log_var in ["uniform", "constant"]: - raise ValueError( - f"latent_log_var={latent_log_var} requires use_quant_conv=False" - ) - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - base_channels=config.get("encoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("blocks")), - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - causal=config.get("causal_decoder", False), - timestep_conditioning=config.get("timestep_conditioning", False), - base_channels=config.get("decoder_base_channels", 128), - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), - ) - - dims = config["dims"] - return CausalVideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - normalize_latent_channels=normalize_latent_channels, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="CausalVideoAutoencoder", - dims=self.dims, - in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, - out_channels=self.decoder.conv_out.out_channels - // self.decoder.patch_size**2, - latent_channels=self.decoder.conv_in.in_channels, - encoder_blocks=self.encoder.blocks_desc, - decoder_blocks=self.decoder.blocks_desc, - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - 
patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - causal_decoder=self.decoder.causal, - timestep_conditioning=self.decoder.timestep_conditioning, - normalize_latent_channels=self.normalize_latent_channels, - ) - - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def spatial_downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.encoder.blocks_desc - if block[0] - in [ - "compress_space", - "compress_all", - "compress_all_res", - "compress_space_res", - ] - ] - ) - * self.encoder.patch_size - ) - - @property - def temporal_downscale_factor(self): - return 2 ** len( - [ - block - for block in self.encoder.blocks_desc - if block[0] - in [ - "compress_time", - "compress_all", - "compress_all_res", - "compress_time_res", - ] - ] - ) - - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - if any([key.startswith("vae.") for key in state_dict.keys()]): - state_dict = { - key.replace("vae.", ""): value - for key, value in state_dict.items() - if key.startswith("vae.") - } - ckpt_state_dict = { - key: value - for key, value in state_dict.items() - if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - - model_keys = set(name for name, _ in self.named_modules()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - converted_state_dict = {} - for key, value in ckpt_state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - key_prefix = ".".join(key.split(".")[:-1]) - if "norm" in key and key_prefix not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - data_dict = { - key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value - for key, value in state_dict.items() - if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) - } - if len(data_dict) > 0: - self.register_buffer("std_of_means", data_dict["std-of-means"]) - self.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) - - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer - - def set_use_tpu_flash_attention(self): - for block in self.decoder.up_blocks: - if isinstance(block, UNetMidBlock3D) and block.attention_blocks: - for attention_block in block.attention_blocks: - attention_block.set_use_tpu_flash_attention() - - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. 
- blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - self.blocks_desc = blocks - - in_channels = in_channels * patch_size**2 - output_channel = base_channels - - self.conv_in = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.down_blocks = nn.ModuleList([]) - - for block_name, block_params in blocks: - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 1, 1), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(1, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_x_y": - output_channel = block_params.get("multiplier", 2) * output_channel - block = make_conv_nd( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - kernel_size=3, - stride=(2, 2, 2), - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all_res": - output_channel = 
block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time_res": - output_channel = block_params.get("multiplier", 2) * output_channel - block = SpaceToDepthDownsample( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown block: {block_name}") - - self.down_blocks.append(block) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var == "constant": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, - output_channel, - conv_out_channels, - 3, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)(sample) - - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - elif self.latent_log_var == "constant": - sample = sample[:, :-1, ...] - approx_ln_0 = ( - -30 - ) # this is the minimal clamp value in DiagonalGaussianDistribution objects - sample = torch.cat( - [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], - dim=1, - ) - - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): - The number of dimensions to use in convolutions. 
- in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): - The blocks to use. Each block is a tuple of the block name and the number of layers. - base_channels (`int`, *optional*, defaults to 128): - The number of output channels for the first convolutional layer. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - causal (`bool`, *optional*, defaults to `True`): - Whether to use causal convolutions or not. - """ - - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], - base_channels: int = 128, - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - causal: bool = True, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.patch_size = patch_size - self.layers_per_block = layers_per_block - out_channels = out_channels * patch_size**2 - self.causal = causal - self.blocks_desc = blocks - - # Compute output channel to be product of all channel-multiplier blocks - output_channel = base_channels - for block_name, block_params in list(reversed(blocks)): - block_params = block_params if isinstance(block_params, dict) else {} - if block_name == "res_x_y": - output_channel = output_channel * block_params.get("multiplier", 2) - if block_name.startswith("compress"): - output_channel = output_channel * block_params.get("multiplier", 1) - - self.conv_in = make_conv_nd( - dims, - in_channels, - output_channel, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.up_blocks = nn.ModuleList([]) - - for block_name, block_params in list(reversed(blocks)): - input_channel = output_channel - if isinstance(block_params, int): - block_params = {"num_layers": block_params} - - if block_name == "res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "res_x_y": - output_channel = output_channel // block_params.get("multiplier", 2) - block = ResnetBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - eps=1e-6, - groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=False, - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_time": - block = 
DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 1, 1), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_space": - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(1, 2, 2), - spatial_padding_mode=spatial_padding_mode, - ) - elif block_name == "compress_all": - output_channel = output_channel // block_params.get("multiplier", 1) - block = DepthToSpaceUpsample( - dims=dims, - in_channels=input_channel, - stride=(2, 2, 2), - residual=block_params.get("residual", False), - out_channels_reduction_factor=block_params.get("multiplier", 1), - spatial_padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unknown layer: {block_name}") - - self.up_blocks.append(block) - - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - elif norm_layer == "layer_norm": - self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) - - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, - output_channel, - out_channels, - 3, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - self.gradient_checkpointing = False - - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter( - torch.tensor(1000.0, dtype=torch.float32) - ) - self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - output_channel * 2, 0 - ) - self.last_scale_shift_table = nn.Parameter( - torch.randn(2, output_channel) / output_channel**0.5 - ) - - def forward( - self, - sample: torch.FloatTensor, - target_shape, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - batch_size = sample.shape[0] - - sample = self.conv_in(sample, causal=self.causal) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - sample = sample.to(upscale_dtype) - - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - sample = self.conv_norm_out(sample) - - if self.timestep_conditioning: - embedded_timestep = self.last_time_embedder( - timestep=scaled_timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=sample.shape[0], - hidden_dtype=sample.dtype, - ) - embedded_timestep = embedded_timestep.view( - batch_size, embedded_timestep.shape[-1], 1, 1, 1 - ) - ada_values = self.last_scale_shift_table[ - None, ..., None, None, None - ] + embedded_timestep.reshape( - batch_size, - 2, - -1, - embedded_timestep.shape[-3], - embedded_timestep.shape[-2], - embedded_timestep.shape[-1], - ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift - - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) - - sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - - return sample - - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - inject_noise (`bool`, *optional*, defaults to `False`): - Whether to inject noise into the hidden states. - timestep_conditioning (`bool`, *optional*, defaults to `False`): - Whether to condition the hidden states on the timestep. - attention_head_dim (`int`, *optional*, defaults to -1): - The dimension of the attention head. If -1, no attention is used. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. - - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - attention_head_dim: int = -1, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( - in_channels * 4, 0 - ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - inject_noise=inject_noise, - timestep_conditioning=timestep_conditioning, - spatial_padding_mode=spatial_padding_mode, - ) - for _ in range(num_layers) - ] - ) - - self.attention_blocks = None - - if attention_head_dim > 0: - if attention_head_dim > in_channels: - raise ValueError( - "attention_head_dim must be less than or equal to in_channels" - ) - - self.attention_blocks = nn.ModuleList( - [ - Attention( - query_dim=in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - bias=True, - out_bias=True, - qk_norm="rms_norm", - residual_connection=True, - ) - for _ in range(num_layers) - ] - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - timestep_embed = None - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - batch_size = hidden_states.shape[0] - timestep_embed = self.time_embedder( - timestep=timestep.flatten(), - resolution=None, - aspect_ratio=None, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - timestep_embed = timestep_embed.view( - batch_size, timestep_embed.shape[-1], 1, 1, 1 - ) - - if self.attention_blocks: - for resnet, attention in zip(self.res_blocks, self.attention_blocks): - hidden_states = resnet( - hidden_states, 
causal=causal, timestep=timestep_embed - ) - - # Reshape the hidden states to be (batch_size, frames * height * width, channel) - batch_size, channel, frames, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, frames * height * width - ).transpose(1, 2) - - if attention.use_tpu_flash_attention: - # Pad the second dimension to be divisible by block_k_major (block in flash attention) - seq_len = hidden_states.shape[1] - block_k_major = 512 - pad_len = (block_k_major - seq_len % block_k_major) % block_k_major - if pad_len > 0: - hidden_states = F.pad( - hidden_states, (0, 0, 0, pad_len), "constant", 0 - ) - - # Create a mask with ones for the original sequence length and zeros for the padded indexes - mask = torch.ones( - (hidden_states.shape[0], seq_len), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - if pad_len > 0: - mask = F.pad(mask, (0, pad_len), "constant", 0) - - hidden_states = attention( - hidden_states, - attention_mask=( - None if not attention.use_tpu_flash_attention else mask - ), - ) - - if attention.use_tpu_flash_attention: - # Remove the padding - if pad_len > 0: - hidden_states = hidden_states[:, :-pad_len, :] - - # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, frames, height, width - ) - else: - for resnet in self.res_blocks: - hidden_states = resnet( - hidden_states, causal=causal, timestep=timestep_embed - ) - - return hidden_states - - -class SpaceToDepthDownsample(nn.Module): - def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): - super().__init__() - self.stride = stride - self.group_size = in_channels * np.prod(stride) // out_channels - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels // np.prod(stride), - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - def forward(self, x, causal: bool = True): - if self.stride[0] == 2: - x = torch.cat( - [x[:, :, :1, :, :], x], dim=2 - ) # duplicate first frames for padding - - # skip connection - x_in = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) - x_in = x_in.mean(dim=2) - - # conv - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - - x = x + x_in - - return x - - -class DepthToSpaceUpsample(nn.Module): - def __init__( - self, - dims, - in_channels, - stride, - residual=False, - out_channels_reduction_factor=1, - spatial_padding_mode="zeros", - ): - super().__init__() - self.stride = stride - self.out_channels = ( - np.prod(stride) * in_channels // out_channels_reduction_factor - ) - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) - self.residual = residual - self.out_channels_reduction_factor = out_channels_reduction_factor - - def forward(self, x, causal: bool = True): - if self.residual: - # Reshape and duplicate the input to match the output shape - x_in = self.pixel_shuffle(x) - num_repeat = np.prod(self.stride) // 
self.out_channels_reduction_factor - x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: - x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = self.pixel_shuffle(x) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x - - -class LayerNorm(nn.Module): - def __init__(self, dim, eps, elementwise_affine=True) -> None: - super().__init__() - self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward(self, x): - x = rearrange(x, "b c d h w -> b d h w c") - x = self.norm(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x - - -class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - inject_noise: bool = False, - timestep_conditioning: bool = False, - spatial_padding_mode: str = "zeros", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.inject_noise = inject_noise - - if norm_layer == "group_norm": - self.norm1 = nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) - - self.non_linearity = nn.SiLU() - - self.conv1 = make_conv_nd( - dims, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - if inject_noise: - self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) - - if norm_layer == "group_norm": - self.norm2 = nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() - elif norm_layer == "layer_norm": - self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = make_conv_nd( - dims, - out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - causal=True, - spatial_padding_mode=spatial_padding_mode, - ) - - if inject_noise: - self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) - - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) - - self.norm3 = ( - LayerNorm(in_channels, eps=eps, elementwise_affine=True) - if in_channels != out_channels - else nn.Identity() - ) - - self.timestep_conditioning = timestep_conditioning - - if timestep_conditioning: - self.scale_shift_table = nn.Parameter( - torch.randn(4, in_channels) / in_channels**0.5 - ) - - def _feed_spatial_noise( - self, hidden_states: torch.FloatTensor, 
per_channel_scale: torch.FloatTensor - ) -> torch.FloatTensor: - spatial_shape = hidden_states.shape[-2:] - device = hidden_states.device - dtype = hidden_states.dtype - - # similar to the "explicit noise inputs" method in style-gan - spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] - scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] - hidden_states = hidden_states + scaled_noise - - return hidden_states - - def forward( - self, - input_tensor: torch.FloatTensor, - causal: bool = True, - timestep: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - hidden_states = input_tensor - batch_size = hidden_states.shape[0] - - hidden_states = self.norm1(hidden_states) - if self.timestep_conditioning: - assert ( - timestep is not None - ), "should pass timestep with timestep_conditioning=True" - ada_values = self.scale_shift_table[ - None, ..., None, None, None - ] + timestep.reshape( - batch_size, - 4, - -1, - timestep.shape[-3], - timestep.shape[-2], - timestep.shape[-1], - ) - shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) - - hidden_states = hidden_states * (1 + scale1) + shift1 - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.conv1(hidden_states, causal=causal) - - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 - ) - - hidden_states = self.norm2(hidden_states) - - if self.timestep_conditioning: - hidden_states = hidden_states * (1 + scale2) + shift2 - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - - hidden_states = self.conv2(hidden_states, causal=causal) - - if self.inject_noise: - hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 - ) - - input_tensor = self.norm3(input_tensor) - - batch_size = input_tensor.shape[0] - - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -def patchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - return x - - -def unpatchify(x, patch_size_hw, patch_size_t=1): - if patch_size_hw == 1 and patch_size_t == 1: - return x - - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - - return x - - -def create_video_autoencoder_demo_config( - latent_channels: int = 64, -): - encoder_blocks = [ - ("res_x", {"num_layers": 2}), - ("compress_space_res", {"multiplier": 2}), - ("compress_time_res", {"multiplier": 2}), - ("compress_all_res", {"multiplier": 2}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 1}), - ] - decoder_blocks = [ - ("res_x", {"num_layers": 2, "inject_noise": False}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("compress_all", {"residual": True, "multiplier": 2}), - ("res_x", {"num_layers": 2, "inject_noise": False}), - ] - return { - "_class_name": 
"CausalVideoAutoencoder", - "dims": 3, - "encoder_blocks": encoder_blocks, - "decoder_blocks": decoder_blocks, - "latent_channels": latent_channels, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, - "timestep_conditioning": True, - "spatial_padding_mode": "replicate", - } - - -def test_vae_patchify_unpatchify(): - import torch - - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) - - -def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_demo_config() - - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = CausalVideoAutoencoder.from_config(config) - - print(video_autoencoder) - video_autoencoder.eval() - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 17, 64, 64) - - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") - - timestep = torch.ones(input_videos.shape[0]) * 0.1 - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape, timestep=timestep - ).sample - - print(f"reconstructed shape={reconstructed_videos.shape}") - - # Validate that single image gets treated the same way as first frame - input_image = input_videos[:, :, :1, :, :] - image_latent = video_autoencoder.encode(input_image).latent_dist.mode() - _ = video_autoencoder.decode( - image_latent, target_shape=image_latent.shape, timestep=timestep - ).sample - - first_frame_latent = latent[:, :, :1, :, :] - - assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) - # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) - # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() - - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - - # Perform backward pass - loss.backward() - - print(f"Demo completed with loss: {loss.item()}") - - -# Ensure to call the demo function to execute the forward and backward pass -if __name__ == "__main__": - demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py b/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py deleted file mode 100644 index 1aa55ed9c..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/conv_nd_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple, Union - -import torch - -from maxdiffusion.models.ltx_video.autoencoders.dual_conv3d import DualConv3d -from maxdiffusion.models.ltx_video.autoencoders.causal_conv3d import CausalConv3d - - -def make_conv_nd( - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - kernel_size: 
int, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - causal=False, - spatial_padding_mode="zeros", - temporal_padding_mode="zeros", -): - if not (spatial_padding_mode == temporal_padding_mode or causal): - raise NotImplementedError("spatial and temporal padding modes must be equal") - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == 3: - if causal: - return CausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - spatial_padding_mode=spatial_padding_mode, - ) - return torch.nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=spatial_padding_mode, - ) - elif dims == (2, 1): - return DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=bias, - padding_mode=spatial_padding_mode, - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - -def make_linear_nd( - dims: int, - in_channels: int, - out_channels: int, - bias=True, -): - if dims == 2: - return torch.nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - elif dims == 3 or dims == (2, 1): - return torch.nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias - ) - else: - raise ValueError(f"unsupported dimensions: {dims}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py b/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py deleted file mode 100644 index dcf889296..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/dual_conv3d.py +++ /dev/null @@ -1,217 +0,0 @@ -import math -from typing import Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - - -class DualConv3d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - dilation: Union[int, Tuple[int, int, int]] = 1, - groups=1, - bias=True, - padding_mode="zeros", - ): - super(DualConv3d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.padding_mode = padding_mode - # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size, kernel_size) - if kernel_size == (1, 1, 1): - raise ValueError( - "kernel_size must be greater than 1. Use make_linear_nd instead." 
- ) - if isinstance(stride, int): - stride = (stride, stride, stride) - if isinstance(padding, int): - padding = (padding, padding, padding) - if isinstance(dilation, int): - dilation = (dilation, dilation, dilation) - - # Set parameters for convolutions - self.groups = groups - self.bias = bias - - # Define the size of the channels after the first convolution - intermediate_channels = ( - out_channels if in_channels < out_channels else in_channels - ) - - # Define parameters for the first convolution - self.weight1 = nn.Parameter( - torch.Tensor( - intermediate_channels, - in_channels // groups, - 1, - kernel_size[1], - kernel_size[2], - ) - ) - self.stride1 = (1, stride[1], stride[2]) - self.padding1 = (0, padding[1], padding[2]) - self.dilation1 = (1, dilation[1], dilation[2]) - if bias: - self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) - else: - self.register_parameter("bias1", None) - - # Define parameters for the second convolution - self.weight2 = nn.Parameter( - torch.Tensor( - out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 - ) - ) - self.stride2 = (stride[0], 1, 1) - self.padding2 = (padding[0], 0, 0) - self.dilation2 = (dilation[0], 1, 1) - if bias: - self.bias2 = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter("bias2", None) - - # Initialize weights and biases - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) - nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) - if self.bias: - fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) - bound1 = 1 / math.sqrt(fan_in1) - nn.init.uniform_(self.bias1, -bound1, bound1) - fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) - bound2 = 1 / math.sqrt(fan_in2) - nn.init.uniform_(self.bias2, -bound2, bound2) - - def forward(self, x, use_conv3d=False, skip_time_conv=False): - if use_conv3d: - return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) - else: - return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) - - def forward_with_3d(self, x, skip_time_conv): - # First convolution - x = F.conv3d( - x, - self.weight1, - self.bias1, - self.stride1, - self.padding1, - self.dilation1, - self.groups, - padding_mode=self.padding_mode, - ) - - if skip_time_conv: - return x - - # Second convolution - x = F.conv3d( - x, - self.weight2, - self.bias2, - self.stride2, - self.padding2, - self.dilation2, - self.groups, - padding_mode=self.padding_mode, - ) - - return x - - def forward_with_2d(self, x, skip_time_conv): - b, c, d, h, w = x.shape - - # First 2D convolution - x = rearrange(x, "b c d h w -> (b d) c h w") - # Squeeze the depth dimension out of weight1 since it's 1 - weight1 = self.weight1.squeeze(2) - # Select stride, padding, and dilation for the 2D convolution - stride1 = (self.stride1[1], self.stride1[2]) - padding1 = (self.padding1[1], self.padding1[2]) - dilation1 = (self.dilation1[1], self.dilation1[2]) - x = F.conv2d( - x, - weight1, - self.bias1, - stride1, - padding1, - dilation1, - self.groups, - padding_mode=self.padding_mode, - ) - - _, _, h, w = x.shape - - if skip_time_conv: - x = rearrange(x, "(b d) c h w -> b c d h w", b=b) - return x - - # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) - - # Reshape weight2 to match the expected dimensions for conv1d - weight2 = self.weight2.squeeze(-1).squeeze(-1) - # Use only the relevant dimension for stride, padding, and 
dilation for the 1D convolution - stride2 = self.stride2[0] - padding2 = self.padding2[0] - dilation2 = self.dilation2[0] - x = F.conv1d( - x, - weight2, - self.bias2, - stride2, - padding2, - dilation2, - self.groups, - padding_mode=self.padding_mode, - ) - x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) - - return x - - @property - def weight(self): - return self.weight2 - - -def test_dual_conv3d_consistency(): - # Initialize parameters - in_channels = 3 - out_channels = 5 - kernel_size = (3, 3, 3) - stride = (2, 2, 2) - padding = (1, 1, 1) - - # Create an instance of the DualConv3d class - dual_conv3d = DualConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=True, - ) - - # Example input tensor - test_input = torch.randn(1, 3, 10, 10, 10) - - # Perform forward passes with both 3D and 2D settings - output_conv3d = dual_conv3d(test_input, use_conv3d=True) - output_2d = dual_conv3d(test_input, use_conv3d=False) - - # Assert that the outputs from both methods are sufficiently close - assert torch.allclose( - output_conv3d, output_2d, atol=1e-6 - ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py b/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py deleted file mode 100644 index 8cb7d7d68..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/latent_upsampler.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Optional, Union -from pathlib import Path -import os -import json - -import torch -import torch.nn as nn -from einops import rearrange -from diffusers import ConfigMixin, ModelMixin -from safetensors.torch import safe_open - -from maxdiffusion.models.ltx_video.autoencoders.pixel_shuffle import PixelShuffleND - - -class ResBlock(nn.Module): - def __init__( - self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 - ): - super().__init__() - if mid_channels is None: - mid_channels = channels - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) - self.norm1 = nn.GroupNorm(32, mid_channels) - self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) - self.norm2 = nn.GroupNorm(32, channels) - self.activation = nn.SiLU() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = self.activation(x) - x = self.conv2(x) - x = self.norm2(x) - x = self.activation(x + residual) - return x - - -class LatentUpsampler(ModelMixin, ConfigMixin): - """ - Model to spatially upsample VAE latents. 
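    Upsampling is channel-to-space: a convolution first expands the channel count by the
    product of the upscale factors (8x when both spatial and temporal upsampling are enabled,
    4x for spatial-only, 2x for temporal-only), and PixelShuffleND then rearranges those extra
    channels into the corresponding spatial and/or temporal dimensions. ResBlocks are applied
    both before and after the upsampling step.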
- - Args: - in_channels (`int`): Number of channels in the input latent - mid_channels (`int`): Number of channels in the middle layers - num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) - dims (`int`): Number of dimensions for convolutions (2 or 3) - spatial_upsample (`bool`): Whether to spatially upsample the latent - temporal_upsample (`bool`): Whether to temporally upsample the latent - """ - - def __init__( - self, - in_channels: int = 128, - mid_channels: int = 512, - num_blocks_per_stage: int = 4, - dims: int = 3, - spatial_upsample: bool = True, - temporal_upsample: bool = False, - ): - super().__init__() - - self.in_channels = in_channels - self.mid_channels = mid_channels - self.num_blocks_per_stage = num_blocks_per_stage - self.dims = dims - self.spatial_upsample = spatial_upsample - self.temporal_upsample = temporal_upsample - - Conv = nn.Conv2d if dims == 2 else nn.Conv3d - - self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) - self.initial_norm = nn.GroupNorm(32, mid_channels) - self.initial_activation = nn.SiLU() - - self.res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - if spatial_upsample and temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(3), - ) - elif spatial_upsample: - self.upsampler = nn.Sequential( - nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(2), - ) - elif temporal_upsample: - self.upsampler = nn.Sequential( - nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(1), - ) - else: - raise ValueError( - "Either spatial_upsample or temporal_upsample must be True" - ) - - self.post_upsample_res_blocks = nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] - ) - - self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, latent: torch.Tensor) -> torch.Tensor: - b, c, f, h, w = latent.shape - - if self.dims == 2: - x = rearrange(latent, "b c f h w -> (b f) c h w") - x = self.initial_conv(x) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - x = self.upsampler(x) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - else: - x = self.initial_conv(latent) - x = self.initial_norm(x) - x = self.initial_activation(x) - - for block in self.res_blocks: - x = block(x) - - if self.temporal_upsample: - x = self.upsampler(x) - x = x[:, :, 1:, :, :] - else: - x = rearrange(x, "b c f h w -> (b f) c h w") - x = self.upsampler(x) - x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) - - for block in self.post_upsample_res_blocks: - x = block(x) - - x = self.final_conv(x) - - return x - - @classmethod - def from_config(cls, config): - return cls( - in_channels=config.get("in_channels", 4), - mid_channels=config.get("mid_channels", 128), - num_blocks_per_stage=config.get("num_blocks_per_stage", 4), - dims=config.get("dims", 2), - spatial_upsample=config.get("spatial_upsample", True), - temporal_upsample=config.get("temporal_upsample", False), - ) - - def config(self): - return { - "_class_name": "LatentUpsampler", - "in_channels": self.in_channels, - "mid_channels": self.mid_channels, - "num_blocks_per_stage": self.num_blocks_per_stage, - "dims": self.dims, - 
"spatial_upsample": self.spatial_upsample, - "temporal_upsample": self.temporal_upsample, - } - - @classmethod - def from_pretrained( - cls, - pretrained_model_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - pretrained_model_path = Path(pretrained_model_path) - if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( - ".safetensors" - ): - state_dict = {} - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - config = json.loads(metadata["config"]) - with torch.device("meta"): - latent_upsampler = LatentUpsampler.from_config(config) - latent_upsampler.load_state_dict(state_dict, assign=True) - return latent_upsampler - - -if __name__ == "__main__": - latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) - print(latent_upsampler) - total_params = sum(p.numel() for p in latent_upsampler.parameters()) - print(f"Total number of parameters: {total_params:,}") - latent = torch.randn(1, 128, 9, 16, 16) - upsampled_latent = latent_upsampler(latent) - print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py deleted file mode 100644 index 9bc3ea60e..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_norm.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch import nn - - -class PixelNorm(nn.Module): - def __init__(self, dim=1, eps=1e-8): - super(PixelNorm, self).__init__() - self.dim = dim - self.eps = eps - - def forward(self, x): - return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py b/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py deleted file mode 100644 index 4e79ae284..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/pixel_shuffle.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch.nn as nn -from einops import rearrange - - -class PixelShuffleND(nn.Module): - def __init__(self, dims, upscale_factors=(2, 2, 2)): - super().__init__() - assert dims in [1, 2, 3], "dims must be 1, 2, or 3" - self.dims = dims - self.upscale_factors = upscale_factors - - def forward(self, x): - if self.dims == 3: - return rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - p3=self.upscale_factors[2], - ) - elif self.dims == 2: - return rearrange( - x, - "b (c p1 p2) h w -> b c (h p1) (w p2)", - p1=self.upscale_factors[0], - p2=self.upscale_factors[1], - ) - elif self.dims == 1: - return rearrange( - x, - "b (c p1) f h w -> b c (f p1) h w", - p1=self.upscale_factors[0], - ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae.py deleted file mode 100644 index 821a6b32b..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae.py +++ /dev/null @@ -1,380 +0,0 @@ -from typing import Optional, Union - -import torch -import inspect -import math -import torch.nn as nn -from diffusers import ConfigMixin, ModelMixin -from diffusers.models.autoencoders.vae import ( - DecoderOutput, - DiagonalGaussianDistribution, -) -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd - - -class AutoencoderKLWrapper(ModelMixin, ConfigMixin): - 
"""Variational Autoencoder (VAE) model with KL loss. - - VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. - This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. - - Args: - encoder (`nn.Module`): - Encoder module. - decoder (`nn.Module`): - Decoder module. - latent_channels (`int`, *optional*, defaults to 4): - Number of latent channels. - """ - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - latent_channels: int = 4, - dims: int = 2, - sample_size=512, - use_quant_conv: bool = True, - normalize_latent_channels: bool = False, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = encoder - self.use_quant_conv = use_quant_conv - self.normalize_latent_channels = normalize_latent_channels - - # pass init params to Decoder - quant_dims = 2 if dims == 2 else 3 - self.decoder = decoder - if use_quant_conv: - self.quant_conv = make_conv_nd( - quant_dims, 2 * latent_channels, 2 * latent_channels, 1 - ) - self.post_quant_conv = make_conv_nd( - quant_dims, latent_channels, latent_channels, 1 - ) - else: - self.quant_conv = nn.Identity() - self.post_quant_conv = nn.Identity() - - if normalize_latent_channels: - if dims == 2: - self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) - else: - self.latent_norm_out = nn.Identity() - self.use_z_tiling = False - self.use_hw_tiling = False - self.dims = dims - self.z_sample_size = 1 - - self.decoder_params = inspect.signature(self.decoder.forward).parameters - - # only relevant if vae tiling is enabled - self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) - - def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): - self.tile_sample_min_size = sample_size - num_blocks = len(self.encoder.down_blocks) - self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) - self.tile_overlap_factor = overlap_factor - - def enable_z_tiling(self, z_sample_size: int = 8): - r""" - Enable tiling during VAE decoding. - - When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_z_tiling = z_sample_size > 1 - self.z_sample_size = z_sample_size - assert ( - z_sample_size % 8 == 0 or z_sample_size == 1 - ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." - - def disable_z_tiling(self): - r""" - Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_z_tiling = False - - def enable_hw_tiling(self): - r""" - Enable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = True - - def disable_hw_tiling(self): - r""" - Disable tiling during VAE decoding along the height and width dimension. - """ - self.use_hw_tiling = False - - def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent - - # Split the image into 512x512 tiles and encode them separately. 
- rows = [] - for i in range(0, x.shape[3], overlap_size): - row = [] - for j in range(0, x.shape[4], overlap_size): - tile = x[ - :, - :, - :, - i : i + self.tile_sample_min_size, - j : j + self.tile_sample_min_size, - ] - tile = self.encoder(tile) - tile = self.quant_conv(tile) - row.append(tile) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - moments = torch.cat(result_rows, dim=3) - return moments - - def blend_z( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[2], b.shape[2], blend_extent) - for z in range(blend_extent): - b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( - 1 - z / blend_extent - ) + b[:, :, z, :, :] * (z / blend_extent) - return b - - def blend_v( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[3], b.shape[3], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( - 1 - y / blend_extent - ) + b[:, :, :, y, :] * (y / blend_extent) - return b - - def blend_h( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: - blend_extent = min(a.shape[4], b.shape[4], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( - 1 - x / blend_extent - ) + b[:, :, :, :, x] * (x / blend_extent) - return b - - def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent - tile_target_shape = ( - *target_shape[:3], - self.tile_sample_min_size, - self.tile_sample_min_size, - ) - # Split z into overlapping 64x64 tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. 
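        # The seams are removed by blend_v/blend_h: over an overlap of blend_extent samples,
        # the already-decoded neighbour is faded out while the current tile is faded in.
        # A minimal sketch of that cross-fade along a single axis (hypothetical helper for
        # illustration only, not called by this method):
        def _illustrative_blend_last_dim(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            # Fade neighbour `a` out and current tile `b` in across the overlapping region.
            blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
            out = b.clone()
            for x in range(blend_extent):
                w = x / blend_extent
                out[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
            return out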
- rows = [] - for i in range(0, z.shape[3], overlap_size): - row = [] - for j in range(0, z.shape[4], overlap_size): - tile = z[ - :, - :, - :, - i : i + self.tile_latent_min_size, - j : j + self.tile_latent_min_size, - ] - tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, target_shape=tile_target_shape) - row.append(decoded) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - return dec - - def encode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: - num_splits = z.shape[2] // self.z_sample_size - sizes = [self.z_sample_size] * num_splits - sizes = ( - sizes + [z.shape[2] - sum(sizes)] - if z.shape[2] - sum(sizes) > 0 - else sizes - ) - tiles = z.split(sizes, dim=2) - moments_tiles = [ - ( - self._hw_tiled_encode(z_tile, return_dict) - if self.use_hw_tiling - else self._encode(z_tile) - ) - for z_tile in tiles - ] - moments = torch.cat(moments_tiles, dim=2) - - else: - moments = ( - self._hw_tiled_encode(z, return_dict) - if self.use_hw_tiling - else self._encode(z) - ) - - posterior = DiagonalGaussianDistribution(moments) - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - _, c, _, _, _ = z.shape - z = torch.cat( - [ - self.latent_norm_out(z[:, : c // 2, :, :, :]), - z[:, c // 2 :, :, :, :], - ], - dim=1, - ) - elif isinstance(self.latent_norm_out, nn.BatchNorm2d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(self.latent_norm_out, nn.BatchNorm3d): - running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) - running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) - eps = self.latent_norm_out.eps - - z = z * torch.sqrt(running_var + eps) + running_mean - elif isinstance(self.latent_norm_out, nn.BatchNorm3d): - raise NotImplementedError("BatchNorm2d not supported") - return z - - def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - moments = self._normalize_latent_channels(moments) - return moments - - def _decode( - self, - z: torch.FloatTensor, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - z = self._unnormalize_latent_channels(z) - z = self.post_quant_conv(z) - if "timestep" in self.decoder_params: - dec = self.decoder(z, target_shape=target_shape, timestep=timestep) - else: - dec = self.decoder(z, target_shape=target_shape) - return dec - - def decode( - self, - z: torch.FloatTensor, - return_dict: bool = True, - target_shape=None, - timestep: Optional[torch.Tensor] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - assert target_shape is not None, "target_shape must be provided for decoding" - if self.use_z_tiling and 
z.shape[2] > self.z_sample_size > 1: - reduction_factor = int( - self.encoder.patch_size_t - * 2 - ** ( - len(self.encoder.down_blocks) - - 1 - - math.sqrt(self.encoder.patch_size) - ) - ) - split_size = self.z_sample_size // reduction_factor - num_splits = z.shape[2] // split_size - - # copy target shape, and divide frame dimension (=2) by the context size - target_shape_split = list(target_shape) - target_shape_split[2] = target_shape[2] // num_splits - - decoded_tiles = [ - ( - self._hw_tiled_decode(z_tile, target_shape_split) - if self.use_hw_tiling - else self._decode(z_tile, target_shape=target_shape_split) - ) - for z_tile in torch.tensor_split(z, num_splits, dim=2) - ] - decoded = torch.cat(decoded_tiles, dim=2) - else: - decoded = ( - self._hw_tiled_decode(z, target_shape) - if self.use_hw_tiling - else self._decode(z, target_shape=target_shape, timestep=timestep) - ) - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`DecoderOutput`] instead of a plain tuple. - generator (`torch.Generator`, *optional*): - Generator used to sample from the posterior. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z, target_shape=sample.shape).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py b/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py deleted file mode 100644 index 5a0aeeccf..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/vae_encode.py +++ /dev/null @@ -1,247 +0,0 @@ -from typing import Tuple -import torch -from diffusers import AutoencoderKL -from einops import rearrange -from torch import Tensor - - -from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( - CausalVideoAutoencoder, -) -from maxdiffusion.models.ltx_video.autoencoders.video_autoencoder import ( - Downsample3D, - VideoAutoencoder, -) - -try: - import torch_xla.core.xla_model as xm -except ImportError: - xm = None - - -def vae_encode( - media_items: Tensor, - vae: AutoencoderKL, - split_size: int = 1, - vae_per_channel_normalize=False, -) -> Tensor: - """ - Encodes media items (images or videos) into latent representations using a specified VAE model. - The function supports processing batches of images or video frames and can handle the processing - in smaller sub-batches if needed. - - Args: - media_items (Tensor): A torch Tensor containing the media items to encode. The expected - shape is (batch_size, channels, height, width) for images or (batch_size, channels, - frames, height, width) for videos. - vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, - pre-configured and loaded with the appropriate model weights. - split_size (int, optional): The number of sub-batches to split the input batch into for encoding. - If set to more than 1, the input media items are processed in smaller batches according to - this value. 
Defaults to 1, which processes all items in a single batch. - - Returns: - Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted - to match the input shape, scaled by the model's configuration. - - Examples: - >>> import torch - >>> from diffusers import AutoencoderKL - >>> vae = AutoencoderKL.from_pretrained('your-model-name') - >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. - >>> latents = vae_encode(images, vae) - >>> print(latents.shape) # Output shape will depend on the model's latent configuration. - - Note: - In case of a video, the function encodes the media item frame-by frame. - """ - is_video_shaped = media_items.dim() == 5 - batch_size, channels = media_items.shape[0:2] - - if channels != 3: - raise ValueError(f"Expects tensors with 3 channels, got {channels}.") - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - media_items = rearrange(media_items, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(media_items) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(media_items) // split_size - # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] - latents = [] - if media_items.device.type == "xla": - xm.mark_step() - for image_batch in media_items.split(encode_bs): - latents.append(vae.encode(image_batch).latent_dist.sample()) - if media_items.device.type == "xla": - xm.mark_step() - latents = torch.cat(latents, dim=0) - else: - latents = vae.encode(media_items).latent_dist.sample() - - latents = normalize_latents(latents, vae, vae_per_channel_normalize) - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) - return latents - - -def vae_decode( - latents: Tensor, - vae: AutoencoderKL, - is_video: bool = True, - split_size: int = 1, - vae_per_channel_normalize=False, - timestep=None, -) -> Tensor: - is_video_shaped = latents.dim() == 5 - batch_size = latents.shape[0] - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - latents = rearrange(latents, "b c n h w -> (b n) c h w") - if split_size > 1: - if len(latents) % split_size != 0: - raise ValueError( - "Error: The batch size must be divisible by 'train.vae_bs_split" - ) - encode_bs = len(latents) // split_size - image_batch = [ - _run_decoder( - latent_batch, vae, is_video, vae_per_channel_normalize, timestep - ) - for latent_batch in latents.split(encode_bs) - ] - images = torch.cat(image_batch, dim=0) - else: - images = _run_decoder( - latents, vae, is_video, vae_per_channel_normalize, timestep - ) - - if is_video_shaped and not isinstance( - vae, (VideoAutoencoder, CausalVideoAutoencoder) - ): - images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) - return images - - -def _run_decoder( - latents: Tensor, - vae: AutoencoderKL, - is_video: bool, - vae_per_channel_normalize=False, - timestep=None, -) -> Tensor: - if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): - *_, fl, hl, wl = latents.shape - temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) - latents = latents.to(vae.dtype) - vae_decode_kwargs = {} - if timestep is not None: - vae_decode_kwargs["timestep"] = timestep - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - 
return_dict=False, - target_shape=( - 1, - 3, - fl * temporal_scale if is_video else 1, - hl * spatial_scale, - wl * spatial_scale, - ), - **vae_decode_kwargs, - )[0] - else: - image = vae.decode( - un_normalize_latents(latents, vae, vae_per_channel_normalize), - return_dict=False, - )[0] - return image - - -def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: - if isinstance(vae, CausalVideoAutoencoder): - spatial = vae.spatial_downscale_factor - temporal = vae.temporal_downscale_factor - else: - down_blocks = len( - [ - block - for block in vae.encoder.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - spatial = vae.config.patch_size * 2**down_blocks - temporal = ( - vae.config.patch_size_t * 2**down_blocks - if isinstance(vae, VideoAutoencoder) - else 1 - ) - - return (temporal, spatial, spatial) - - -def latent_to_pixel_coords( - latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False -) -> Tensor: - """ - Converts latent coordinates to pixel coordinates by scaling them according to the VAE's - configuration. - - Args: - latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] - containing the latent corner coordinates of each token. - vae (AutoencoderKL): The VAE model - causal_fix (bool): Whether to take into account the different temporal scale - of the first frame. Default = False for backwards compatibility. - Returns: - Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. - """ - - scale_factors = get_vae_size_scale_factor(vae) - causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix - pixel_coords = latent_to_pixel_coords_from_factors( - latent_coords, scale_factors, causal_fix - ) - return pixel_coords - - -def latent_to_pixel_coords_from_factors( - latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False -) -> Tensor: - pixel_coords = ( - latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] - ) - if causal_fix: - # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) - return pixel_coords - - -def normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) - / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents * vae.config.scaling_factor - ) - - -def un_normalize_latents( - latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False -) -> Tensor: - return ( - latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) - if vae_per_channel_normalize - else latents / vae.config.scaling_factor - ) diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py b/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py deleted file mode 100644 index 5b9ea640b..000000000 --- a/src/maxdiffusion/models/ltx_video/autoencoders/video_autoencoder.py +++ /dev/null @@ -1,1045 +0,0 @@ -import json -import os -from functools import partial -from types import SimpleNamespace -from typing import Any, Mapping, Optional, Tuple, Union - -import torch -from einops import rearrange -from torch import nn -from torch.nn import functional - -from diffusers.utils import logging - -from maxdiffusion.models.ltx_video.utils.torch_utils import Identity -from 
maxdiffusion.models.ltx_video.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd -from maxdiffusion.models.ltx_video.autoencoders.pixel_norm import PixelNorm -from maxdiffusion.models.ltx_video.autoencoders.vae import AutoencoderKLWrapper - -logger = logging.get_logger(__name__) - - -class VideoAutoencoder(AutoencoderKLWrapper): - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *args, - **kwargs, - ): - config_local_path = pretrained_model_name_or_path / "config.json" - config = cls.load_config(config_local_path, **kwargs) - video_vae = cls.from_config(config) - video_vae.to(kwargs["torch_dtype"]) - - model_local_path = pretrained_model_name_or_path / "autoencoder.pth" - ckpt_state_dict = torch.load(model_local_path) - video_vae.load_state_dict(ckpt_state_dict) - - statistics_local_path = ( - pretrained_model_name_or_path / "per_channel_statistics.json" - ) - if statistics_local_path.exists(): - with open(statistics_local_path, "r") as file: - data = json.load(file) - transposed_data = list(zip(*data["data"])) - data_dict = { - col: torch.tensor(vals) - for col, vals in zip(data["columns"], transposed_data) - } - video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) - video_vae.register_buffer( - "mean_of_means", - data_dict.get( - "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) - ), - ) - - return video_vae - - @staticmethod - def from_config(config): - assert ( - config["_class_name"] == "VideoAutoencoder" - ), "config must have _class_name=VideoAutoencoder" - if isinstance(config["dims"], list): - config["dims"] = tuple(config["dims"]) - - assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - use_quant_conv = config.get("use_quant_conv", True) - - if use_quant_conv and latent_log_var == "uniform": - raise ValueError("uniform latent_log_var requires use_quant_conv=False") - - encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - block_out_channels=config["block_out_channels"], - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), - add_channel_padding=config.get("add_channel_padding", False), - ) - - dims = config["dims"] - return VideoAutoencoder( - encoder=encoder, - decoder=decoder, - latent_channels=config["latent_channels"], - dims=dims, - use_quant_conv=use_quant_conv, - ) - - @property - def config(self): - return SimpleNamespace( - _class_name="VideoAutoencoder", - dims=self.dims, - in_channels=self.encoder.conv_in.in_channels - // (self.encoder.patch_size_t * self.encoder.patch_size**2), - out_channels=self.decoder.conv_out.out_channels - // (self.decoder.patch_size_t * self.decoder.patch_size**2), - latent_channels=self.decoder.conv_in.in_channels, - block_out_channels=[ - 
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels - for i in range(len(self.encoder.down_blocks)) - ], - scaling_factor=1.0, - norm_layer=self.encoder.norm_layer, - patch_size=self.encoder.patch_size, - latent_log_var=self.encoder.latent_log_var, - use_quant_conv=self.use_quant_conv, - patch_size_t=self.encoder.patch_size_t, - add_channel_padding=self.encoder.add_channel_padding, - ) - - @property - def is_video_supported(self): - """ - Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. - """ - return self.dims != 2 - - @property - def downscale_factor(self): - return self.encoder.downsample_factor - - def to_json_string(self) -> str: - import json - - return json.dumps(self.config.__dict__) - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - model_keys = set(name for name, _ in self.named_parameters()) - - key_mapping = { - ".resnets.": ".res_blocks.", - "downsamplers.0": "downsample", - "upsamplers.0": "upsample", - } - - converted_state_dict = {} - for key, value in state_dict.items(): - for k, v in key_mapping.items(): - key = key.replace(k, v) - - if "norm" in key and key not in model_keys: - logger.info( - f"Removing key {key} from state_dict as it is not present in the model" - ) - continue - - converted_state_dict[key] = value - - super().load_state_dict(converted_state_dict, strict=strict) - - def last_layer(self): - if hasattr(self.decoder, "conv_out"): - if isinstance(self.decoder.conv_out, nn.Sequential): - last_layer = self.decoder.conv_out[-1] - else: - last_layer = self.decoder.conv_out - else: - last_layer = self.decoder.layers[-1] - return last_layer - - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]] = 3, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] 
= (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: Union[int, Tuple[int]] = 1, - norm_layer: str = "group_norm", # group_norm, pixel_norm - latent_log_var: str = "per_channel", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var - if add_channel_padding: - in_channels = in_channels * self.patch_size**3 - else: - in_channels = in_channels * self.patch_size_t * self.patch_size**2 - self.in_channels = in_channels - output_channel = block_out_channels[0] - - self.conv_in = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - padding=1, - ) - - self.down_blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels)): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - dims=dims, - in_channels=input_channel, - out_channels=output_channel, - num_layers=self.layers_per_block, - add_downsample=not is_final_block and 2**i >= patch_size, - resnet_eps=1e-6, - downsample_padding=0, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.down_blocks.append(down_block) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - # out - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], - num_groups=norm_num_groups, - eps=1e-6, - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - self.conv_act = nn.SiLU() - - conv_out_channels = out_channels - if latent_log_var == "per_channel": - conv_out_channels *= 2 - elif latent_log_var == "uniform": - conv_out_channels += 1 - elif latent_log_var != "none": - raise ValueError(f"Invalid latent_log_var: {latent_log_var}") - self.conv_out = make_conv_nd( - dims, block_out_channels[-1], conv_out_channels, 3, padding=1 - ) - - self.gradient_checkpointing = False - - @property - def downscale_factor(self): - return ( - 2 - ** len( - [ - block - for block in self.down_blocks - if isinstance(block.downsample, Downsample3D) - ] - ) - * self.patch_size - ) - - def forward( - self, sample: torch.FloatTensor, return_features=False - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - downsample_in_time = sample.shape[2] != 1 - - # patchify - patch_size_t = self.patch_size_t if downsample_in_time else 1 - sample = patchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) - - sample = self.conv_in(sample) - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - if return_features: - features = [] - for down_block in self.down_blocks: - sample = checkpoint_fn(down_block)( - sample, downsample_in_time=downsample_in_time - ) - if return_features: - features.append(sample) - - sample = checkpoint_fn(self.mid_block)(sample) - - # post-process - sample = 
self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if self.latent_log_var == "uniform": - last_channel = sample[:, -1:, ...] - num_dims = sample.dim() - - if num_dims == 4: - # For shape (B, C, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - elif num_dims == 5: - # For shape (B, C, F, H, W) - repeated_last_channel = last_channel.repeat( - 1, sample.shape[1] - 2, 1, 1, 1 - ) - sample = torch.cat([sample, repeated_last_channel], dim=1) - else: - raise ValueError(f"Invalid input shape: {sample.shape}") - - if return_features: - features.append(sample[:, : self.latent_channels, ...]) - return sample, features - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - patch_size (`int`, *optional*, defaults to 1): - The patch size to use. Should be a power of 2. - norm_layer (`str`, *optional*, defaults to `group_norm`): - The normalization layer to use. Can be either `group_norm` or `pixel_norm`. - """ - - def __init__( - self, - dims, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - patch_size: int = 1, - norm_layer: str = "group_norm", - patch_size_t: Optional[int] = None, - add_channel_padding: Optional[bool] = False, - ): - super().__init__() - self.patch_size = patch_size - self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size - self.add_channel_padding = add_channel_padding - self.layers_per_block = layers_per_block - if add_channel_padding: - out_channels = out_channels * self.patch_size**3 - else: - out_channels = out_channels * self.patch_size_t * self.patch_size**2 - self.out_channels = out_channels - - self.conv_in = make_conv_nd( - dims, - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - self.mid_block = UNetMidBlock3D( - dims=dims, - in_channels=block_out_channels[-1], - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - dims=dims, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block - and 2 ** (len(block_out_channels) - i - 1) > patch_size, - resnet_eps=1e-6, - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - ) - self.up_blocks.append(up_block) - - if norm_layer == "group_norm": - self.conv_norm_out = nn.GroupNorm( - 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - elif norm_layer == "pixel_norm": - self.conv_norm_out = PixelNorm() - - self.conv_act = nn.SiLU() - self.conv_out = make_conv_nd( - dims, block_out_channels[0], out_channels, 3, padding=1 - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - assert target_shape is not None, "target_shape must be provided" - upsample_in_time = sample.shape[2] < target_shape[2] - - sample = self.conv_in(sample) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - - checkpoint_fn = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - if self.gradient_checkpointing and self.training - else lambda x: x - ) - - sample = checkpoint_fn(self.mid_block)(sample) - sample = sample.to(upscale_dtype) - - for up_block in self.up_blocks: - sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - # un-patchify - patch_size_t = self.patch_size_t if upsample_in_time else 1 - sample = unpatchify( - sample, - patch_size_hw=self.patch_size, - patch_size_t=patch_size_t, - add_channel_padding=self.add_channel_padding, - ) - - return sample - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_downsample: bool = True, - downsample_padding: int = 1, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - - if add_downsample: - self.downsample = Downsample3D( - dims, - out_channels, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsample = Identity() - - def forward( - self, hidden_states: torch.FloatTensor, downsample_in_time - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - hidden_states = self.downsample( - hidden_states, downsample_in_time=downsample_in_time - ) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. - - Args: - in_channels (`int`): The number of input channels. - dropout (`float`, *optional*, defaults to 0.0): The dropout rate. - num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. - resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. - resnet_groups (`int`, *optional*, defaults to 32): - The number of groups to use in the group normalization layers of the resnet blocks. - - Returns: - `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, - in_channels, height, width)`. 
- - """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - norm_layer: str = "group_norm", - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - self.res_blocks = nn.ModuleList( - [ - ResnetBlock3D( - dims=dims, - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_groups: int = 32, - add_upsample: bool = True, - norm_layer: str = "group_norm", - ): - super().__init__() - res_blocks = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - res_blocks.append( - ResnetBlock3D( - dims=dims, - in_channels=input_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - norm_layer=norm_layer, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - - if add_upsample: - self.upsample = Upsample3D( - dims=dims, channels=out_channels, out_channels=out_channels - ) - else: - self.upsample = Identity() - - self.resolution_idx = resolution_idx - - def forward( - self, hidden_states: torch.FloatTensor, upsample_in_time=True - ) -> torch.FloatTensor: - for resnet in self.res_blocks: - hidden_states = resnet(hidden_states) - - hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
- """ - - def __init__( - self, - dims: Union[int, Tuple[int, int]], - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - groups: int = 32, - eps: float = 1e-6, - norm_layer: str = "group_norm", - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - if norm_layer == "group_norm": - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm1 = PixelNorm() - - self.non_linearity = nn.SiLU() - - self.conv1 = make_conv_nd( - dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if norm_layer == "group_norm": - self.norm2 = torch.nn.GroupNorm( - num_groups=groups, num_channels=out_channels, eps=eps, affine=True - ) - elif norm_layer == "pixel_norm": - self.norm2 = PixelNorm() - - self.dropout = torch.nn.Dropout(dropout) - - self.conv2 = make_conv_nd( - dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - self.conv_shortcut = ( - make_linear_nd( - dims=dims, in_channels=in_channels, out_channels=out_channels - ) - if in_channels != out_channels - else nn.Identity() - ) - - def forward( - self, - input_tensor: torch.FloatTensor, - ) -> torch.FloatTensor: - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - - hidden_states = self.norm2(hidden_states) - - hidden_states = self.non_linearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - - hidden_states = self.conv2(hidden_states) - - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -class Downsample3D(nn.Module): - def __init__( - self, - dims, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - padding: int = 1, - ): - super().__init__() - stride: int = 2 - self.padding = padding - self.in_channels = in_channels - self.dims = dims - self.conv = make_conv_nd( - dims=dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - - def forward(self, x, downsample_in_time=True): - conv = self.conv - if self.padding == 0: - if self.dims == 2: - padding = (0, 1, 0, 1) - else: - padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) - - x = functional.pad(x, padding, mode="constant", value=0) - - if self.dims == (2, 1) and not downsample_in_time: - return conv(x, skip_time_conv=True) - - return conv(x) - - -class Upsample3D(nn.Module): - """ - An upsampling layer for 3D tensors of shape (B, C, D, H, W). - - :param channels: channels in the inputs and outputs. 
- """ - - def __init__(self, dims, channels, out_channels=None): - super().__init__() - self.dims = dims - self.channels = channels - self.out_channels = out_channels or channels - self.conv = make_conv_nd( - dims, channels, out_channels, kernel_size=3, padding=1, bias=True - ) - - def forward(self, x, upsample_in_time): - if self.dims == 2: - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - else: - time_scale_factor = 2 if upsample_in_time else 1 - # print("before:", x.shape) - b, c, d, h, w = x.shape - x = rearrange(x, "b c d h w -> (b d) c h w") - # height and width interpolate - x = functional.interpolate( - x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" - ) - _, _, h, w = x.shape - - if not upsample_in_time and self.dims == (2, 1): - x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) - return self.conv(x, skip_time_conv=True) - - # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension - x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) - - # (b h w) c 1 d - new_d = x.shape[-1] * time_scale_factor - x = functional.interpolate(x, (1, new_d), mode="nearest") - # (b h w) c 1 new_d - x = rearrange( - x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d - ) - # b c d h w - - # x = functional.interpolate( - # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - # ) - # print("after:", x.shape) - - return self.conv(x) - - -def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b c (f p) (h q) (w r) -> b (c p r q) f h w", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] - padding_zeros = torch.zeros( - x.shape[0], - channels_to_pad, - x.shape[2], - x.shape[3], - x.shape[4], - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([padding_zeros, x], dim=1) - - return x - - -def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): - if patch_size_hw == 1 and patch_size_t == 1: - return x - - if ( - (x.dim() == 5) - and (patch_size_hw > patch_size_t) - and (patch_size_t > 1 or add_channel_padding) - ): - channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) - x = x[:, :channels_to_keep, :, :, :] - - if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw - ) - elif x.dim() == 5: - x = rearrange( - x, - "b (c p r q) f h w -> b c (f p) (h q) (w r)", - p=patch_size_t, - q=patch_size_hw, - r=patch_size_hw, - ) - - return x - - -def create_video_autoencoder_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [ - 128, - 256, - 512, - 512, - ], # Number of output channels of each encoder / decoder inner block - 
"patch_size": 1, - } - - return config - - -def create_video_autoencoder_pathify4x4x4_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": ( - 2, - 1, - ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "latent_log_var": "uniform", - } - - return config - - -def create_video_autoencoder_pathify4x4_config( - latent_channels: int = 4, -): - config = { - "_class_name": "VideoAutoencoder", - "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d - "in_channels": 3, # Number of input color channels (e.g., RGB) - "out_channels": 3, # Number of output color channels - "latent_channels": latent_channels, # Number of channels in the latent space representation - "block_out_channels": [512] - * 4, # Number of output channels of each encoder / decoder inner block - "patch_size": 4, - "norm_layer": "pixel_norm", - } - - return config - - -def test_vae_patchify_unpatchify(): - import torch - - x = torch.randn(2, 3, 8, 64, 64) - x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) - x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) - assert torch.allclose(x, x_unpatched) - - -def demo_video_autoencoder_forward_backward(): - # Configuration for the VideoAutoencoder - config = create_video_autoencoder_pathify4x4x4_config() - - # Instantiate the VideoAutoencoder with the specified configuration - video_autoencoder = VideoAutoencoder.from_config(config) - - print(video_autoencoder) - - # Print the total number of parameters in the video autoencoder - total_params = sum(p.numel() for p in video_autoencoder.parameters()) - print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") - - # Create a mock input tensor simulating a batch of videos - # Shape: (batch_size, channels, depth, height, width) - # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame - input_videos = torch.randn(2, 3, 8, 64, 64) - - # Forward pass: encode and decode the input videos - latent = video_autoencoder.encode(input_videos).latent_dist.mode() - print(f"input shape={input_videos.shape}") - print(f"latent shape={latent.shape}") - reconstructed_videos = video_autoencoder.decode( - latent, target_shape=input_videos.shape - ).sample - - print(f"reconstructed shape={reconstructed_videos.shape}") - - # Calculate the loss (e.g., mean squared error) - loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) - - # Perform backward pass - loss.backward() - - print(f"Demo completed with loss: {loss.item()}") - - -# Ensure to call the demo function to execute the forward and backward pass -if __name__ == "__main__": - demo_video_autoencoder_forward_backward() diff --git a/src/maxdiffusion/models/ltx_video/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py similarity index 100% rename from src/maxdiffusion/models/ltx_video/autoencoders/__init__.py rename to src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 06b367ce4..31c6b5b15 100644 --- 
a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -8,109 +8,109 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + @nn.compact + def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: + data_input, index_input = carry - @nn.compact - def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]: - data_input, index_input = carry + mod = self.module(*self.module_init_args, **self.module_init_kwargs) - mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers - output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers + next_index = index_input + 1 + new_carry = (output_data, next_index) - next_index = index_input + 1 - new_carry = (output_data, next_index) - + return new_carry, None - return new_carry, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - if not args: - raise ValueError("RepeatableLayer expects at least one argument for initial data input.") + @nn.compact + def __call__(self, *args): + if not args: + raise ValueError("RepeatableLayer expects at least one argument for initial data input.") - initial_data_input = args[0] - static_block_args = args[1:] + 
initial_data_input = args[0] + static_block_args = args[1:] - initial_index = jnp.array(0, dtype=jnp.int32) #index of current transformer block + initial_index = jnp.array(0, dtype=jnp.int32) # index of current transformer block - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) - in_axes_for_scan = (nn.broadcast,) * (len(args)-1) + in_axes_for_scan = (nn.broadcast,) * (len(args) - 1) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={"params": True}, - in_axes=in_axes_for_scan, - length=self.num_layers, - **scan_kwargs, - ) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True}, + in_axes=in_axes_for_scan, + length=self.num_layers, + **scan_kwargs, + ) - wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) - # Call wrapped_function with the initial carry tuple and the static_block_args - (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) + # Call wrapped_function with the initial carry tuple and the static_block_args + (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args) - return final_data \ No newline at end of file + return final_data diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index caff27add..75692b703 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -15,7 +15,6 @@ # This implementation is based on the Torch version available at: # https://github.com/Lightricks/LTX-Video/tree/main from functools import partial -import functools import math from typing import Any, Dict, Optional, Tuple from enum import Enum, auto @@ -40,899 +39,885 @@ ) - class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - def __call__(self, x): - return x + + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: 
str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." - - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. 
Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + index: int, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + skip_layer_strategy = SkipLayerStrategy.AttentionValues + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + + attn_output = self.attn1( + norm_hidden_states, + block_index=index, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + block_index=-1, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. 
Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), dtype=self.dtype, weight_dtype=self.weight_dtype, + name="to_out.0", matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - - def __call__( - self, - index: int, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - skip_layer_strategy = SkipLayerStrategy.AttentionValues - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + block_index: int = -1, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} #noqa: F821 + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape( + skip_layer_mask[block_index], (batch_size, 1, 1) + ) # here skip_layer_mask is (48,3), changed this currently! 
+ + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + elif skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues: + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def 
prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - # 0. Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. 
Self-Attention - - attn_output = self.attn1( - norm_hidden_states, - block_index = index, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) - attn_output = nn.with_logical_constraint( - attn_output, ("activation_batch", "activation_norm_length", "activation_embed") - ) +class AttentionOp(nn.Module): - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint( - attn_input, ("activation_batch", "activation_norm_length", "activation_embed") - ) - attn_output = self.attn2( - attn_input, - block_index = -1, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") - ) + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
+ # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkvo_sharding_spec = jax.sharding.PartitionSpec( + "data", + "fsdp", + None, + "tensor", + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint( - ff_output, ("activation_batch", "activation_norm_length", "activation_embed") - ) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states + TPU kernel ops run in grids: the bigger the grid, the more data is loaded into SRAM, + and we want to utilize that SRAM as well as we can. + Grids that are too large will cause cache misses and slow down the computation while the faster SRAM fetches the next block's data + from the slower HBM -class Attention(nn.Module): - query_dim: int - cross_attention_dim: Optional[int] = None - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - bias: bool = False - upcast_attention: bool = False - upcast_softmax: bool = False - cross_attention_norm: Optional[str] = None - added_kv_proj_dim: Optional[int] = None - out_bias: bool = True - scale_qk: bool = True - qk_norm: Optional[str] = None - only_cross_attention: bool = False - eps: float = 1e-5 - rescale_output_factor: float = 1.0 - residual_connection: bool = False - out_dim: Optional[int] = None - use_tpu_flash_attention: bool = True - use_rope: bool = False - attention_op: Optional[nn.Module] = None - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers in Flax `setup()`.""" - self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads - self.use_bias = self.bias - self.is_cross_attention = self.cross_attention_dim is not None - self.fused_projections = False - out_dim = self.out_dim if self.out_dim is not None else self.query_dim - self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 - - # Query and Key Normalization - if self.qk_norm is None: - self.q_norm = 
Identity() - self.k_norm = Identity() - elif self.qk_norm == "rms_norm": - self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) - elif self.qk_norm == "layer_norm": - self.q_norm = nn.LayerNorm(epsilon=self.eps) - self.k_norm = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") - - if out_dim is not None: - self.heads_count = out_dim // self.dim_head - - # Validate parameters - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " - "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if self.cross_attention_norm is None: - self.norm_cross = None - elif self.cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(epsilon=self.eps) - else: - raise ValueError( - f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." - ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="to_out.0", - matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any(device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print("Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - block_index: int = -1, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> 
jnp.ndarray: - cross_attention_kwargs = { k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters } - assert cross_attention_kwargs.get("scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask[block_index], (batch_size, 1, 1)) #here skip_layer_mask is (48,3), changed this currently! - - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key = self.k_norm(key) - if self.use_rope: - key = apply_rotary_emb(key, freqs_cis) - query = apply_rotary_emb(query, freqs_cis) - - value = self.to_v(encoder_hidden_states) - value_for_stg = value - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) - query = jnp.swapaxes(query, 1, 2) - query = nn.with_logical_constraint( - query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - query = checkpoint_name(query, "attention query") - - key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) - key = jnp.swapaxes(key, 1, 2) - key = nn.with_logical_constraint( - key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - key = checkpoint_name(key, "attention key") + ** SRAM cache size for TPU + V5P - 1MB SRAM per core - value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) - value = jnp.swapaxes(value, 1, 2) - value = nn.with_logical_constraint( - value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - ) - value = checkpoint_name(value, "attention value") - - assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used - q_segment_ids = segment_ids - if q_segment_ids is not None: - q_segment_ids = q_segment_ids.astype(jnp.float32) + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) 
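For a concrete sense of how the defaults above resolve, a minimal sketch (the sequence lengths and dtype below are illustrative only, not values taken from this patch): block dimensions are capped at 1024 for bfloat16 (512 otherwise) and then clamped to the actual sequence lengths.

import jax.numpy as jnp

# Illustrative example: 4096 query tokens attending to 256 key/value tokens in bfloat16.
q_tokens, kv_tokens, dtype = 4096, 256, jnp.bfloat16
max_block_size = 1024 if dtype == jnp.bfloat16 else 512
block_q = min(max_block_size, q_tokens)    # 1024: query grid dimension
block_k = min(max_block_size, kv_tokens)   # 256: clamped to the shorter kv sequence
block_k_dq = min(512, kv_tokens)           # 256: the dq backward pass uses the smaller 512 cap
print(block_q, block_k, block_k_dq)        # -> 1024 256 256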
- if kv_attention_segment_ids is not None and q_segment_ids is None: - q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) - hidden_states_a = self.attention( - query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype - ) +class ExplicitAttention(nn.Module): - hidden_states_a: jax.Array = nn.with_logical_constraint( - hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") - ) + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." + attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v - hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) - elif ( - skip_layer_mask is not None - and skip_layer_strategy == SkipLayerStrategy.AttentionValues - ): - hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( - 1.0 - skip_layer_mask - ) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = 
jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. -class AttentionOp(nn.Module): - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
- # qkvo_sharding_spec = jax.sharding.PartitionSpec( - # ("data", "fsdp", "fsdp_transpose", "expert"), - # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - # None, - # None, - # ) - # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - qkvo_sharding_spec = jax.sharding.PartitionSpec( - "data", - "fsdp", - None, - "tensor", - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. - - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can - - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM - - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size - - ** SRAM cache size for TPU - V5P - 1MB SRAM per core - - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used - - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - ) + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + Returns: + jax.Array: Normed data + """ -class ExplicitAttention(nn.Module): - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." 
- attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. - """ + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. - - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. - - NOTE: if weight is in mixed precision, the operand should be in the same precision. - Args: - hidden_states (jax.Array): Input data - - Returns: - jax.Array: Normed data - """ - - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) - - if self.elementwise_affine: - # convert into half-precision if necessary - hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) - else: - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError(f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out \ No newline at end of file + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py deleted file mode 100644 index 2eca32033..000000000 --- a/src/maxdiffusion/models/ltx_video/transformers/symmetric_patchifier.py +++ /dev/null @@ -1,84 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Tuple - -import torch -from diffusers.configuration_utils import ConfigMixin -from einops import rearrange -from torch import Tensor - - -class Patchifier(ConfigMixin, ABC): - def __init__(self, patch_size: int): - super().__init__() - self._patch_size = (1, patch_size, patch_size) - - @abstractmethod - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - raise NotImplementedError("Patchify method not implemented") - - @abstractmethod - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - pass - - @property - def patch_size(self): - return self._patch_size - - def get_latent_coords( - self, latent_num_frames, latent_height, latent_width, batch_size, device - ): - """ - Return a tensor of shape [batch_size, 3, num_patches] containing the - top-left corner latent coordinates of each latent patch. - The tensor is repeated for each batch element. 
- """ - latent_sample_coords = torch.meshgrid( - torch.arange(0, latent_num_frames, self._patch_size[0], device=device), - torch.arange(0, latent_height, self._patch_size[1], device=device), - torch.arange(0, latent_width, self._patch_size[2], device=device), - ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size - ) - return latent_coords - - -class SymmetricPatchifier(Patchifier): - def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: - b, _, f, h, w = latents.shape - latent_coords = self.get_latent_coords(f, h, w, b, latents.device) - latents = rearrange( - latents, - "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", - p1=self._patch_size[0], - p2=self._patch_size[1], - p3=self._patch_size[2], - ) - return latents, latent_coords - - def unpatchify( - self, - latents: Tensor, - output_height: int, - output_width: int, - out_channels: int, - ) -> Tuple[Tensor, Tensor]: - output_height = output_height // self._patch_size[1] - output_width = output_width // self._patch_size[2] - latents = rearrange( - latents, - "b (f h w) (c p q) -> b c f (h p) (w q)", - h=output_height, - w=output_width, - p=self._patch_size[1], - q=self._patch_size[2], - ) - return latents diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index e3110a82b..1c1807fdd 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -29,304 +29,298 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - def init_weights(self, in_channels, key, caption_channels, eval_only=True): - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "hidden_states": (batch_size, num_tokens, in_channels), - "indices_grid": (batch_size, 3, num_tokens), - "encoder_hidden_states": (batch_size, 128, caption_channels), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] 
else jnp.bool - ) - - if eval_only: - return jax.eval_shape( - self.init, - key, - **example_inputs, - )["params"] - else: - return self.init(key, **example_inputs)["params"] - def create_skip_layer_mask( - self, - batch_size: int, - num_conds: int, - ptb_index: int, - skip_block_list: Optional[List[int]] = None, - ) -> Optional[jnp.ndarray]: - if skip_block_list is None or len(skip_block_list) == 0: - return None - mask = jnp.ones( - (self.num_layers, batch_size * num_conds), dtype=self.dtype - ) - - for block_idx in skip_block_list: - mask = mask.at[block_idx, ptb_index::num_conds].set(0) - - return mask - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - skip_layer_mask=None, - skip_layer_strategy=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - - if self.num_layers > 0: - - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - skip_layer_mask, - skip_layer_strategy, - ) - # Output processing - - scale_shift_values = ( - self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] - ) - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + 
matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( #noqa: C408 + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + 
"timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + key, + **example_inputs, + )["params"] + else: + return self.init(key, **example_inputs)["params"] + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ) -> Optional[jnp.ndarray]: + if skip_block_list is None or len(skip_block_list) == 0: + return None + mask = jnp.ones((self.num_layers, batch_size * num_conds), dtype=self.dtype) + + for block_idx in skip_block_list: + mask = mask.at[block_idx, ptb_index::num_conds].set(0) + + return mask + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + skip_layer_mask=None, + skip_layer_strategy=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask, + skip_layer_strategy, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. - - Args: - x (jax.Array): log value - base (jax.Array): base of the log - - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) - + """ + Computes log of x with defined base. + Args: + x (jax.Array): log value + base (jax.Array): base of the log + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. 
- """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - dtype = jnp.float32 # We need full precision in the freqs_cis computation. - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. 
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py deleted file mode 100644 index 53c0082d1..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py +++ /dev/null @@ -1,174 +0,0 @@ -def make_hashable_key(dict_key): - def convert_value(value): - if isinstance(value, list): - return tuple(value) - elif isinstance(value, dict): - return tuple(sorted((k, convert_value(v)) for k, v in value.items())) - else: - return value - - return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) - - -DIFFUSERS_SCHEDULER_CONFIG = { - "_class_name": "FlowMatchEulerDiscreteScheduler", - "_diffusers_version": "0.32.0.dev0", - "base_image_seq_len": 1024, - "base_shift": 0.95, - "invert_sigmas": False, - "max_image_seq_len": 4096, - "max_shift": 2.05, - "num_train_timesteps": 1000, - "shift": 1.0, - "shift_terminal": 0.1, - "use_beta_sigmas": False, - "use_dynamic_shifting": True, - "use_exponential_sigmas": False, - "use_karras_sigmas": False, -} -DIFFUSERS_TRANSFORMER_CONFIG = { - "_class_name": "LTXVideoTransformer3DModel", - "_diffusers_version": "0.32.0.dev0", - "activation_fn": "gelu-approximate", - "attention_bias": True, - "attention_head_dim": 64, - "attention_out_bias": True, - "caption_channels": 4096, - "cross_attention_dim": 2048, - "in_channels": 128, - "norm_elementwise_affine": False, - "norm_eps": 1e-06, - "num_attention_heads": 32, - "num_layers": 28, - "out_channels": 128, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm_across_heads", -} -DIFFUSERS_VAE_CONFIG = { - "_class_name": "AutoencoderKLLTXVideo", - "_diffusers_version": "0.32.0.dev0", - "block_out_channels": [128, 256, 512, 512], - "decoder_causal": False, - "encoder_causal": True, - "in_channels": 3, - "latent_channels": 128, - "layers_per_block": [4, 3, 3, 3, 4], - "out_channels": 3, - "patch_size": 4, - "patch_size_t": 1, - "resnet_norm_eps": 1e-06, - "scaling_factor": 1.0, - "spatio_temporal_scaling": [True, True, True, False], -} - -OURS_SCHEDULER_CONFIG = { - "_class_name": "RectifiedFlowScheduler", - "_diffusers_version": "0.25.1", - "num_train_timesteps": 1000, - "shifting": "SD3", - "base_resolution": None, - "target_shift_terminal": 0.1, -} - -OURS_TRANSFORMER_CONFIG = { - "_class_name": "Transformer3DModel", - "_diffusers_version": "0.25.1", - "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", - "activation_fn": "gelu-approximate", - "attention_bias": True, - "attention_head_dim": 64, - "attention_type": 
"default", - "caption_channels": 4096, - "cross_attention_dim": 2048, - "double_self_attention": False, - "dropout": 0.0, - "in_channels": 128, - "norm_elementwise_affine": False, - "norm_eps": 1e-06, - "norm_num_groups": 32, - "num_attention_heads": 32, - "num_embeds_ada_norm": 1000, - "num_layers": 28, - "num_vector_embeds": None, - "only_cross_attention": False, - "out_channels": 128, - "project_to_2d_pos": True, - "upcast_attention": False, - "use_linear_projection": False, - "qk_norm": "rms_norm", - "standardization_norm": "rms_norm", - "positional_embedding_type": "rope", - "positional_embedding_theta": 10000.0, - "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000, -} -OURS_VAE_CONFIG = { - "_class_name": "CausalVideoAutoencoder", - "dims": 3, - "in_channels": 3, - "out_channels": 3, - "latent_channels": 128, - "blocks": [ - ["res_x", 4], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x", 3], - ["res_x", 4], - ], - "scaling_factor": 1.0, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, -} - - -diffusers_and_ours_config_mapping = { - make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, - make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, - make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, -} - - -TRANSFORMER_KEYS_RENAME_DICT = { - "proj_in": "patchify_proj", - "time_embed": "adaln_single", - "norm_q": "q_norm", - "norm_k": "k_norm", -} - - -VAE_KEYS_RENAME_DICT = { - "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", - "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", - "decoder.up_blocks.3": "decoder.up_blocks.9", - "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", - "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", - "decoder.up_blocks.2": "decoder.up_blocks.6", - "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", - "decoder.up_blocks.1": "decoder.up_blocks.3", - "decoder.up_blocks.0": "decoder.up_blocks.1", - "decoder.mid_block": "decoder.up_blocks.0", - "encoder.down_blocks.3": "encoder.down_blocks.8", - "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", - "encoder.down_blocks.2": "encoder.down_blocks.6", - "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", - "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", - "encoder.down_blocks.1": "encoder.down_blocks.3", - "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", - "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", - "encoder.down_blocks.0": "encoder.down_blocks.0", - "encoder.mid_block": "encoder.down_blocks.9", - "conv_shortcut.conv": "conv_shortcut", - "resnets": "res_blocks", - "norm3": "norm3.norm", - "latents_mean": "per_channel_statistics.mean-of-means", - "latents_std": "per_channel_statistics.std-of-means", -} diff --git a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py b/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py deleted file mode 100644 index 901051728..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/prompt_enhance_utils.py +++ /dev/null @@ -1,226 +0,0 @@ -import logging -from typing import Union, List, Optional - -import torch -from PIL import Image - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When 
writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. -Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. -Start directly with the action, and keep descriptions literal and precise. -Think like a cinematographer describing a shot list. -Do not change the user input intent, just enhance it. -Keep within 150 words. -For best results, build your prompts using this structure: -Start with main action in a single sentence -Add specific details about movements and gestures -Describe character/object appearances precisely -Include background and environment details -Specify camera angles and movements -Describe lighting and colors -Note any changes or sudden events -Do not exceed the 150 word limit! -Output the enhanced prompt only. -""" - -I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. -Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. -Start directly with the action, and keep descriptions literal and precise. -Think like a cinematographer describing a shot list. -Keep within 150 words. -For best results, build your prompts using this structure: -Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. -Start with main action in a single sentence -Add specific details about movements and gestures -Describe character/object appearances precisely -Include background and environment details -Specify camera angles and movements -Describe lighting and colors -Note any changes or sudden events -Align to the image caption if it contradicts the user text input. -Do not exceed the 150 word limit! -Output the enhanced prompt only. 
-""" - - -def tensor_to_pil(tensor): - # Ensure tensor is in range [-1, 1] - assert tensor.min() >= -1 and tensor.max() <= 1 - - # Convert from [-1, 1] to [0, 1] - tensor = (tensor + 1) / 2 - - # Rearrange from [C, H, W] to [H, W, C] - tensor = tensor.permute(1, 2, 0) - - # Convert to numpy array and then to uint8 range [0, 255] - numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") - - # Convert to PIL Image - return Image.fromarray(numpy_image) - - -def generate_cinematic_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompt: Union[str, List[str]], - conditioning_items: Optional[List] = None, - max_new_tokens: int = 256, -) -> List[str]: - prompts = [prompt] if isinstance(prompt, str) else prompt - - if conditioning_items is None: - prompts = _generate_t2v_prompt( - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - max_new_tokens, - T2V_CINEMATIC_PROMPT, - ) - else: - if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: - logger.warning( - "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" - ) - return prompts - - first_frame_conditioning_item = conditioning_items[0] - first_frames = _get_first_frames_from_conditioning_item( - first_frame_conditioning_item - ) - - assert len(first_frames) == len( - prompts - ), "Number of conditioning frames must match number of prompts" - - prompts = _generate_i2v_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts, - first_frames, - max_new_tokens, - I2V_CINEMATIC_PROMPT, - ) - - return prompts - - -def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: - frames_tensor = conditioning_item.media_item - return [ - tensor_to_pil(frames_tensor[i, :, 0, :, :]) - for i in range(frames_tensor.shape[0]) - ] - - -def _generate_t2v_prompt( - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts: List[str], - max_new_tokens: int, - system_prompt: str, -) -> List[str]: - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}"}, - ] - for p in prompts - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) - - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens - ) - - -def _generate_i2v_prompt( - image_caption_model, - image_caption_processor, - prompt_enhancer_model, - prompt_enhancer_tokenizer, - prompts: List[str], - first_frames: List[Image.Image], - max_new_tokens: int, - system_prompt: str, -) -> List[str]: - image_captions = _generate_image_captions( - image_caption_model, image_caption_processor, first_frames - ) - - messages = [ - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, - ] - for p, c in zip(prompts, image_captions) - ] - - texts = [ - prompt_enhancer_tokenizer.apply_chat_template( - m, tokenize=False, add_generation_prompt=True - ) - for m in messages - ] - model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( - prompt_enhancer_model.device - ) - - return _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, 
max_new_tokens - ) - - -def _generate_image_captions( - image_caption_model, - image_caption_processor, - images: List[Image.Image], - system_prompt: str = "", -) -> List[str]: - image_caption_prompts = [system_prompt] * len(images) - inputs = image_caption_processor( - image_caption_prompts, images, return_tensors="pt" - ).to(image_caption_model.device) - - with torch.inference_mode(): - generated_ids = image_caption_model.generate( - input_ids=inputs["input_ids"], - pixel_values=inputs["pixel_values"], - max_new_tokens=1024, - do_sample=False, - num_beams=3, - ) - - return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) - - -def _generate_and_decode_prompts( - prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int -) -> List[str]: - with torch.inference_mode(): - outputs = prompt_enhancer_model.generate( - **model_inputs, max_new_tokens=max_new_tokens - ) - generated_ids = [ - output_ids[len(input_ids) :] - for input_ids, output_ids in zip(model_inputs.input_ids, outputs) - ] - decoded_prompts = prompt_enhancer_tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - ) - - return decoded_prompts diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py deleted file mode 100644 index 30f9016e1..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum, auto - - -class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py b/src/maxdiffusion/models/ltx_video/utils/torch_utils.py deleted file mode 100644 index 991b07c36..000000000 --- a/src/maxdiffusion/models/ltx_video/utils/torch_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch import nn - - -def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - elif dims_to_append == 0: - return x - return x[(...,) + (None,) * dims_to_append] - - -class Identity(nn.Module): - """A placeholder identity operator that is argument-insensitive.""" - - def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument - super().__init__() - - # pylint: disable=unused-argument - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - return x diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 5584fcc0f..c4767e8ee 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -14,26 +14,26 @@ import math import os from jax import Array -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler +from maxdiffusion.models.ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler from diffusers import AutoencoderKL from typing import Optional, List, Union, Tuple from einops import rearrange import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor from transformers import ( - FlaxT5EncoderModel, - AutoTokenizer, + FlaxT5EncoderModel, AutoModelForCausalLM, AutoProcessor, - 
AutoTokenizer,) + AutoTokenizer, +) import json import numpy as np import torch from huggingface_hub import hf_hub_download -from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import ( +from maxdiffusion.models.ltx_video.models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) -from maxdiffusion.models.ltx_video.autoencoders.vae_encode import ( +from maxdiffusion.models.ltx_video.models.autoencoders.vae_encode import ( get_vae_size_scale_factor, latent_to_pixel_coords, vae_decode, @@ -42,935 +42,739 @@ normalize_latents, ) from diffusers.image_processor import VaeImageProcessor -from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt -from math import e from types import NoneType from typing import Any, Dict -import numpy as np import jax import jax.numpy as jnp -from jax.sharding import Mesh, PartitionSpec as P -from typing import Optional, Union, List +from jax.sharding import Mesh from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier from ...pyconfig import HyperParameters from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState -from ...max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations -) +from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import json import functools import orbax.checkpoint as ocp + def prepare_extra_step_kwargs(generator): - extra_step_kwargs = {} - extra_step_kwargs["generator"] = generator - return extra_step_kwargs + extra_step_kwargs = {} + extra_step_kwargs["generator"] = generator + return extra_step_kwargs class LTXVideoPipeline: - def __init__( - self, - transformer: Transformer3DModel, - scheduler: FlaxRectifiedFlowMultistepScheduler, - scheduler_state: RectifiedFlowSchedulerState, - vae: AutoencoderKL, - text_encoder, - patchifier, - tokenizer, + + def __init__( + self, + transformer: Transformer3DModel, + scheduler: FlaxRectifiedFlowMultistepScheduler, + scheduler_state: RectifiedFlowSchedulerState, + vae: AutoencoderKL, + text_encoder, + patchifier, + tokenizer, + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + transformer_state: Dict[Any, Any] = None, + transformer_state_shardings: Dict[Any, Any] = NoneType, + ): + self.transformer = transformer + self.devices_array = devices_array + self.mesh = mesh + self.config = config + self.p_run_inference = None + self.transformer_state = transformer_state + self.transformer_state_shardings = transformer_state_shardings + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.vae = vae + self.text_encoder = text_encoder + self.patchifier = patchifier + self.tokenizer = tokenizer + self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model + self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor + self.prompt_enhancer_llm_model = prompt_enhancer_llm_model + self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae) + self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @classmethod + def load_scheduler(cls, ckpt_path, config): + if config.sampler == "from_checkpoint" or not config.sampler: + scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) + else: + scheduler = FlaxRectifiedFlowMultistepScheduler( + sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") + ) + scheduler_state = scheduler.create_state() + + return scheduler, scheduler_state + + @classmethod + def load_transformer(cls, config): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True + ) + ##load in jax weights checkpoint + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + return transformer, transformer_state, transformer_state_shardings + + @classmethod + def load_vae(cls, ckpt_path): + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + return vae + + @classmethod + def load_text_encoder(cls, ckpt_path): + t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) + return t5_encoder + + @classmethod + def load_tokenizer(cls, config, ckpt_path): + t5_tokenizer = AutoTokenizer.from_pretrained(ckpt_path, max_length=config.max_sequence_length, use_fast=True) + return t5_tokenizer + + @classmethod + def load_prompt_enhancement(cls, config): + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + config.prompt_enhancer_llm_model_name_or_path, + ) + return ( prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - transformer_state: Dict[Any, Any] = None, - transformer_state_shardings: Dict[Any, Any] = NoneType, + ) + + @classmethod + def 
from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + transformer, transformer_state, transformer_state_shardings = cls.load_transformer(config) + + # load from pytorch version + models_dir = config.models_dir + ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) + vae = cls.load_vae(ltxv_model_path) + vae = vae.to(torch.bfloat16) + text_encoder = cls.load_text_encoder(config.text_encoder_model_name_or_path) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = cls.load_tokenizer(config, config.text_encoder_model_name_or_path) + + enhance_prompt = False + if enhance_prompt: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = cls.load_prompt_enhancement(config) + else: + ( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer, + ) = (None, None, None, None) + + return LTXVideoPipeline( + transformer=transformer, + scheduler=scheduler, + scheduler_state=scheduler_state, + vae=vae, + text_encoder=text_encoder, + patchifier=patchifier, + tokenizer=tokenizer, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + devices_array=devices_array, + mesh=mesh, + config=config, + transformer_state=transformer_state, + transformer_state_shardings=transformer_state_shardings, + ) + + @classmethod + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + def denoising_step( + scheduler, + latents: Array, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + extra_step_kwargs: Dict, + t_eps: float = 1e-6, + stochastic_sampling: bool = False, + ) -> Array: + # Denoise the latents using the scheduler + denoised_latents = scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) + + def retrieve_timesteps( # currently doesn't support custom timesteps + self, + scheduler: FlaxRectifiedFlowMultistepScheduler, + latent_shape, + scheduler_state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + ): + scheduler_state = scheduler.set_timesteps( + state=scheduler_state, samples_shape=latent_shape, 
num_inference_steps=num_inference_steps + ) + timesteps = scheduler_state.timesteps + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps >= num_inference_steps ): - self.transformer = transformer - self.devices_array = devices_array - self.mesh = mesh - self.config = config - self.p_run_inference = None - self.transformer_state = transformer_state - self.transformer_state_shardings = transformer_state_shardings - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.vae = vae - self.text_encoder = text_encoder - self.patchifier = patchifier - self.tokenizer = tokenizer - self.prompt_enhancer_image_caption_model = prompt_enhancer_image_caption_model - self.prompt_enhancer_image_caption_processor = prompt_enhancer_image_caption_processor - self.prompt_enhancer_llm_model = prompt_enhancer_llm_model - self.prompt_enhancer_llm_tokenizer = prompt_enhancer_llm_tokenizer - self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( - self.vae - ) - self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor) - - @classmethod - def load_scheduler(cls, ckpt_path, config): - if config.sampler == "from_checkpoint" or not config.sampler: - scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ckpt_path) - else: - scheduler = FlaxRectifiedFlowMultistepScheduler( - sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic") - ) - scheduler_state = scheduler.create_state() - - return scheduler, scheduler_state - - @classmethod - def load_transformer(cls, config): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - - ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", - "causal_temporal_positioning", "in_channels", "ckpt_path"] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) - - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True - ) - ##load in jax weights checkpoint - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - transformer_state = jax.device_put( - transformer_state, transformer_state_shardings) - get_memory_allocations() - - return transformer, transformer_state, transformer_state_shardings - - @classmethod - def load_vae(cls, ckpt_path): - vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) - return vae - - @classmethod - def load_text_encoder(cls, ckpt_path): - t5_encoder = FlaxT5EncoderModel.from_pretrained(ckpt_path) - return t5_encoder - - @classmethod - def load_tokenizer(cls, config, ckpt_path): - t5_tokenizer = 
AutoTokenizer.from_pretrained( - ckpt_path, max_length=config.max_sequence_length, use_fast=True - ) - return t5_tokenizer - - @classmethod - def load_prompt_enhancement(cls, config): - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( - config.prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True - ) - prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, torch_dtype="bfloat16", - ) - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( - config.prompt_enhancer_llm_model_name_or_path, - ) - return prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer - - @classmethod - def from_pretrained(cls, config: HyperParameters, enhance_prompt: bool = False): - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - - transformer, transformer_state, transformer_state_shardings = cls.load_transformer( - config) - - # load from pytorch version - models_dir = config.models_dir - ltxv_model_name_or_path = "ltxv-13b-0.9.7-dev.safetensors" - if not os.path.isfile(ltxv_model_name_or_path): - ltxv_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=ltxv_model_name_or_path, - local_dir=models_dir, - repo_type="model", - ) - else: - ltxv_model_path = ltxv_model_name_or_path - - scheduler, scheduler_state = cls.load_scheduler(ltxv_model_path, config) - vae = cls.load_vae(ltxv_model_path) - vae = vae.to(torch.bfloat16) - text_encoder = cls.load_text_encoder( - config.text_encoder_model_name_or_path) - patchifier = SymmetricPatchifier(patch_size=1) - tokenizer = cls.load_tokenizer( - config, config.text_encoder_model_name_or_path) - - enhance_prompt = False - if enhance_prompt: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = cls.load_prompt_enhancement( - config) - else: - prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None - - return LTXVideoPipeline( - transformer=transformer, - scheduler=scheduler, - scheduler_state=scheduler_state, - vae=vae, - text_encoder=text_encoder, - patchifier=patchifier, - tokenizer=tokenizer, - prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model=prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, - devices_array=devices_array, - mesh=mesh, - config=config, - transformer_state=transformer_state, - transformer_state_shardings=transformer_state_shardings - ) - - @classmethod - def _text_preprocessing(self, text): - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - text = text.strip() - return text - - return [process(t) for t in text] - - def denoising_step( - scheduler, - latents: Array, - noise_pred: Array, - current_timestep: Optional[Array], - conditioning_mask: Optional[Array], - t: float, - extra_step_kwargs: Dict, - t_eps: float = 1e-6, - stochastic_sampling: bool = False, - ) -> Array: - # Denoise the latents using the scheduler - denoised_latents = scheduler.step( - 
noise_pred, - t if current_timestep is None else current_timestep, - latents, - **extra_step_kwargs, - stochastic_sampling=stochastic_sampling, - ) - - if conditioning_mask is None: - return denoised_latents - - tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) - tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) - return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) - - def retrieve_timesteps( #currently doesn't support custom timesteps - self, - scheduler: FlaxRectifiedFlowMultistepScheduler, + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + timesteps = timesteps[skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps] + scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape=latent_shape, state=scheduler_state) + + return scheduler_state + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = text_encoder_max_tokens + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = jnp.array(text_inputs.input_ids) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) #noqa: F841 + + prompt_attention_mask = jnp.array(text_inputs.attention_mask) + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype #noqa: F841 + elif self.transformer is not None: + dtype = self.transformer.dtype #noqa: F841 + else: + dtype = None #noqa: F841 + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_attention_mask = jnp.tile(prompt_attention_mask, (1, num_images_per_prompt)) + prompt_attention_mask = jnp.reshape(prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + uncond_tokens = uncond_tokens * batch_size + max_length = 
prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) + + negative_prompt_embeds = self.text_encoder( + jnp.array(uncond_input.input_ids), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = jnp.tile(negative_prompt_embeds, (1, num_images_per_prompt, 1)) + negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, (1, num_images_per_prompt)) + negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + def prepare_latents( ## this is in pytorch + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." + ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items, + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert latents.shape == latent_shape, f"Latents have to be of shape {latent_shape} but are {latents.shape}." 
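+    # The noise below is drawn directly in the patchified token layout (b, f * h * w, c)
+    # and rearranged back to (b, c, f, h, w); when initial latents are provided they are
+    # blended with this noise as timestep * noise + (1 - timestep) * latents, consistent
+    # with the rectified-flow interpolation used by the scheduler in this pipeline.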
+ + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor((b, f * h * w, c), generator=generator, device=device, dtype=dtype) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + timestep = torch.from_numpy(np.array(timestep)) + latents = timestep * noise + (1 - timestep) * latents + + return latents + + def prepare_conditioning( # removed conditioning_item logic + self, + conditioning_items, + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = True, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + + assert isinstance(self.vae, CausalVideoAutoencoder) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now + causal_fix=True, + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + def __call__( + self, + height: int, + width: int, + num_frames: int, + negative_prompt: str = "", + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + frame_rate: int = 30, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_timesteps: Optional[List[int]] = None, + decode_timestep: Union[List[float], float] = 0.05, + decode_noise_scale: Optional[List[float]] = 0.025, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + num_inference_steps: int = 50, + guidance_scale: Union[float, List[float]] = 4.5, + rescaling_scale: Union[float, List[float]] = 0.7, + stg_scale: Union[float, List[float]] = 1.0, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + cfg_star_rescale: bool = False, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + **kwargs, + ): + enhance_prompt = False + prompt = self.config.prompt + is_video = kwargs.get("is_video", False) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True) + import pdb + + pdb.set_trace() + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = 
json.load(f) + + latent_shape = ( + batch_size * num_images_per_prompt, + model_config["in_channels"], + latent_num_frames, + latent_height, + latent_width, + ) + scheduler_state = self.retrieve_timesteps( + self.scheduler, latent_shape, - scheduler_state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - timesteps: Optional[List[int]] = None, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - ): - scheduler_state = scheduler.set_timesteps(state=scheduler_state, samples_shape=latent_shape, num_inference_steps=num_inference_steps) - timesteps = scheduler_state.timesteps - if ( - skip_initial_inference_steps < 0 - or skip_final_inference_steps < 0 - or skip_initial_inference_steps + skip_final_inference_steps - >= num_inference_steps - ): - raise ValueError( - "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" - ) - timesteps = timesteps[ - skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + self.scheduler_state, + num_inference_steps, + None, + skip_initial_inference_steps, + skip_final_inference_steps, + ) + + guidance_mapping = [] + + if guidance_timesteps: + for timestep in scheduler_state.timesteps: + indices = [i for i, val in enumerate(guidance_timesteps) if val <= timestep] + guidance_mapping.append(indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)) + + if not isinstance(guidance_scale, list): + guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) + else: + guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(stg_scale, list): + stg_scale = [stg_scale] * len(scheduler_state.timesteps) + else: + stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + if not isinstance(rescaling_scale, list): + rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) + else: + rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] + + guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] + do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) + do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) + do_rescaling = any(x != 1.0 for x in rescaling_scale) + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) + + if not is_list_of_lists: + skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) + else: + new_skip_block_list = [] + for i in range(len(scheduler_state.timesteps)): + new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) + + skip_block_list = new_skip_block_list + + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_masks = [ + self.transformer.create_skip_layer_mask(batch_size, num_conds, num_conds - 1, skip_blocks) + for skip_blocks in skip_block_list ] - scheduler_state = scheduler.set_timesteps(timesteps=timesteps, samples_shape = latent_shape, state=scheduler_state) - - - return scheduler_state - - def encode_prompt( - self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = 
None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - text_encoder_max_tokens: int = 256, - **kwargs, - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - max_length = ( - text_encoder_max_tokens - ) - if prompt_embeds is None: - assert ( - self.text_encoder is not None - ), "You should provide either prompt_embeds or self.text_encoder should not be None," - - prompt = self._text_preprocessing(prompt) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = jnp.array(text_inputs.input_ids) - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, max_length - 1: -1] - ) - - prompt_attention_mask = jnp.array(text_inputs.attention_mask) - prompt_embeds = self.text_encoder( - text_input_ids, attention_mask=prompt_attention_mask - ) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype - else: - dtype = None - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - prompt_embeds = jnp.reshape( - prompt_embeds, (bs_embed * num_images_per_prompt, seq_len, -1)) - prompt_attention_mask = jnp.tile( - prompt_attention_mask, (1, num_images_per_prompt)) - prompt_attention_mask = jnp.reshape( - prompt_attention_mask, (bs_embed * num_images_per_prompt, -1)) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = self._text_preprocessing(negative_prompt) - uncond_tokens = uncond_tokens * batch_size - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = jnp.array(uncond_input.attention_mask) - - negative_prompt_embeds = self.text_encoder( - jnp.array(uncond_input.input_ids), - attention_mask=negative_prompt_attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = jnp.tile(negative_prompt_embeds, - (1, num_images_per_prompt, 1) - ) - negative_prompt_embeds = jnp.reshape(negative_prompt_embeds, - (batch_size * num_images_per_prompt, seq_len, -1) - ) - - negative_prompt_attention_mask = jnp.tile(negative_prompt_attention_mask, - (1, num_images_per_prompt) - ) - negative_prompt_attention_mask = jnp.reshape(negative_prompt_attention_mask, - (bs_embed * num_images_per_prompt, -1) - ) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - - return ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) - - def prepare_latents( ## this is in pytorch - self, - latents: 
torch.Tensor | None, - media_items: torch.Tensor | None, - timestep: float, - latent_shape: torch.Size | Tuple[Any, ...], - dtype: torch.dtype, - device: torch.device, - generator: torch.Generator | List[torch.Generator], - vae_per_channel_normalize: bool = True, - ): - if isinstance(generator, list) and len(generator) != latent_shape[0]: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." - ) - - # Initialize the latents with the given latents or encoded media item, if provided - assert ( - latents is None or media_items is None - ), "Cannot provide both latents and media_items. Please provide only one of the two." - - assert ( - latents is None and media_items is None or timestep < 1.0 - ), "Input media_item or latents are provided, but they will be replaced with noise." - - if media_items is not None: - latents = vae_encode( - media_items, - self.vae, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - if latents is not None: - assert ( - latents.shape == latent_shape - ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." - - # For backward compatibility, generate in the "patchified" shape and rearrange - b, c, f, h, w = latent_shape - noise = randn_tensor( - (b, f * h * w, c), generator=generator, device=device, dtype=dtype - ) - noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) - - # scale the initial noise by the standard deviation required by the scheduler - # noise = noise * self.scheduler.init_noise_sigma !!this doesn;t have - - if latents is None: - latents = noise - else: - # Noise the latents to the required (first) timestep - timestep = torch.from_numpy(np.array(timestep)) - latents = timestep * noise + (1 - timestep) * latents - - return latents - - def prepare_conditioning( - self, - conditioning_items, - init_latents: torch.Tensor, - num_frames: int, - height: int, - width: int, - vae_per_channel_normalize: bool = True, - generator=None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - - assert isinstance(self.vae, CausalVideoAutoencoder) - - if conditioning_items: - batch_size, _, num_latent_frames = init_latents.shape[:3] - - init_conditioning_mask = torch.zeros( - init_latents[:, 0, :, :, :].shape, - dtype=torch.float32, - device=init_latents.device, - ) - - extra_conditioning_latents = [] - extra_conditioning_pixel_coords = [] - extra_conditioning_mask = [] - extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) - - # Process each conditioning item - for conditioning_item in conditioning_items: - conditioning_item = self._resize_conditioning_item( - conditioning_item, height, width - ) - media_item = conditioning_item.media_item - media_frame_number = conditioning_item.media_frame_number - strength = conditioning_item.conditioning_strength - assert media_item.ndim == 5 # (b, c, f, h, w) - b, c, n_frames, h, w = media_item.shape - assert ( - height == h and width == w - ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" - assert n_frames % 8 == 1 - assert ( - media_frame_number >= 0 - and media_frame_number + n_frames <= num_frames - ) - - # Encode the provided conditioning media item - media_item_latents = vae_encode( - media_item.to(dtype=self.vae.dtype, device=self.vae.device), - self.vae, - 
vae_per_channel_normalize=vae_per_channel_normalize, - ).to(dtype=init_latents.dtype) - - # Handle the different conditioning cases - if media_frame_number == 0: - # Get the target spatial position of the latent conditioning item - media_item_latents, l_x, l_y = self._get_latent_spatial_position( - media_item_latents, - conditioning_item, - height, - width, - strip_latent_border=True, - ) - b, c_l, f_l, h_l, w_l = media_item_latents.shape - - # First frame or sequence - just update the initial noise latents and the mask - init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( - torch.lerp( - init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], - media_item_latents, - strength, - ) - ) - init_conditioning_mask[ - :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l - ] = strength - else: - # Non-first frame or sequence - if n_frames > 1: - # Handle non-first sequence. - # Encoded latents are either fully consumed, or the prefix is handled separately below. - ( - init_latents, - init_conditioning_mask, - media_item_latents, - ) = self._handle_non_first_conditioning_sequence( - init_latents, - init_conditioning_mask, - media_item_latents, - media_frame_number, - strength, - ) - - # Single frame or sequence-prefix latents - if media_item_latents is not None: - noise = randn_tensor( - media_item_latents.shape, - generator=generator, - device=media_item_latents.device, - dtype=media_item_latents.dtype, - ) - - media_item_latents = torch.lerp( - noise, media_item_latents, strength - ) - - # Patchify the extra conditioning latents and calculate their pixel coordinates - media_item_latents, latent_coords = self.patchifier.patchify( - latents=media_item_latents - ) - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.vae, - causal_fix=self.transformer.config.causal_temporal_positioning, - ) - - # Update the frame numbers to match the target frame number - pixel_coords[:, 0] += media_frame_number - extra_conditioning_num_latents += media_item_latents.shape[1] - - conditioning_mask = torch.full( - media_item_latents.shape[:2], - strength, - dtype=torch.float32, - device=init_latents.device, - ) - - extra_conditioning_latents.append(media_item_latents) - extra_conditioning_pixel_coords.append(pixel_coords) - extra_conditioning_mask.append(conditioning_mask) - - # Patchify the updated latents and calculate their pixel coordinates - init_latents, init_latent_coords = self.patchifier.patchify( - latents=init_latents - ) - init_pixel_coords = latent_to_pixel_coords( - init_latent_coords, - self.vae, - # causal_fix=self.transformer.config.causal_temporal_positioning, set to false now - causal_fix=True - ) - - if not conditioning_items: - return init_latents, init_pixel_coords, None, 0 - - init_conditioning_mask, _ = self.patchifier.patchify( - latents=init_conditioning_mask.unsqueeze(1) - ) - init_conditioning_mask = init_conditioning_mask.squeeze(-1) - - if extra_conditioning_latents: - # Stack the extra conditioning latents, pixel coordinates and mask - init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) - init_pixel_coords = torch.cat( - [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 - ) - init_conditioning_mask = torch.cat( - [*extra_conditioning_mask, init_conditioning_mask], dim=1 - ) - - if self.transformer.use_tpu_flash_attention: - # When flash attention is used, keep the original number of tokens by removing - # tokens from the end. 
- init_latents = init_latents[:, :-extra_conditioning_num_latents] - init_pixel_coords = init_pixel_coords[ - :, :, :-extra_conditioning_num_latents - ] - init_conditioning_mask = init_conditioning_mask[ - :, :-extra_conditioning_num_latents - ] - - return ( - init_latents, - init_pixel_coords, - init_conditioning_mask, - extra_conditioning_num_latents, - ) - - - - def __call__( - self, - height: int, - width: int, - num_frames: int, - negative_prompt: str = "", - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - frame_rate: int = 30, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - guidance_timesteps: Optional[List[int]] = None, - decode_timestep: Union[List[float], float] = 0.05, - decode_noise_scale: Optional[List[float]] = 0.025, - offload_to_cpu: bool = False, - enhance_prompt: bool = False, - text_encoder_max_tokens: int = 256, - num_inference_steps: int = 50, - guidance_scale: Union[float, List[float]] = 4.5, - rescaling_scale: Union[float, List[float]] = 0.7, - stg_scale: Union[float, List[float]] = 1.0, - skip_initial_inference_steps: int = 0, - skip_final_inference_steps: int = 0, - cfg_star_rescale: bool = False, - skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, - **kwargs, - ): - enhance_prompt = False - prompt = self.config.prompt - is_video = kwargs.get("is_video", False) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - vae_per_channel_normalize = kwargs.get( - "vae_per_channel_normalize", True) - import pdb; pdb.set_trace() - - latent_height = height // self.vae_scale_factor - latent_width = width // self.vae_scale_factor - latent_num_frames = num_frames // self.video_scale_factor - if isinstance(self.vae, CausalVideoAutoencoder) and is_video: - latent_num_frames += 1 - base_dir = os.path.dirname(__file__) - config_path = os.path.join( - base_dir, "../../models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) - - latent_shape = ( - batch_size * num_images_per_prompt, - model_config["in_channels"], - latent_num_frames, - latent_height, - latent_width, - ) - scheduler_state = self.retrieve_timesteps(self.scheduler, latent_shape, self.scheduler_state, num_inference_steps, None, skip_initial_inference_steps, skip_final_inference_steps) - - - guidance_mapping = [] - - if guidance_timesteps: - for timestep in scheduler_state.timesteps: - indices = [ - i for i, val in enumerate(guidance_timesteps) if val <= timestep - ] - guidance_mapping.append( - indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) - ) - - if not isinstance(guidance_scale, list): - guidance_scale = [guidance_scale] * len(scheduler_state.timesteps) - else: - guidance_scale = [guidance_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not isinstance(stg_scale, list): - stg_scale = [stg_scale] * len(scheduler_state.timesteps) - else: - stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - if not 
isinstance(rescaling_scale, list): - rescaling_scale = [rescaling_scale] * len(scheduler_state.timesteps) - else: - rescaling_scale = [rescaling_scale[guidance_mapping[i]] for i in range(len(scheduler_state.timesteps))] - - guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale] - do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale) - do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale) - do_rescaling = any(x != 1.0 for x in rescaling_scale) - - - num_conds = 1 - if do_classifier_free_guidance: - num_conds += 1 - if do_spatio_temporal_guidance: - num_conds += 1 - - is_list_of_lists = bool(skip_block_list) and isinstance(skip_block_list[0], list) - - if not is_list_of_lists: - skip_block_list = [skip_block_list] * len(scheduler_state.timesteps) - else: - new_skip_block_list = [] - for i in range(len(scheduler_state.timesteps)): - new_skip_block_list.append(skip_block_list[guidance_mapping[i]]) - - skip_block_list = new_skip_block_list - - if do_spatio_temporal_guidance: - if skip_block_list is not None: - skip_layer_masks = [ - self.transformer.create_skip_layer_mask( - batch_size, num_conds, num_conds - 1, skip_blocks - ) - for skip_blocks in skip_block_list - ] - if enhance_prompt: - prompt = generate_cinematic_prompt( - self.prompt_enhancer_image_caption_model, - self.prompt_enhancer_image_caption_processor, - self.prompt_enhancer_llm_model, - self.prompt_enhancer_llm_tokenizer, - prompt, - None, # conditioning items set to None - max_new_tokens=text_encoder_max_tokens, - ) - - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - device=None, # device set to none - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - text_encoder_max_tokens=text_encoder_max_tokens, - ) - prompt_embeds_batch = prompt_embeds - prompt_attention_mask_batch = prompt_attention_mask - if do_classifier_free_guidance: - prompt_embeds_batch = jnp.concatenate( - [negative_prompt_embeds, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [negative_prompt_attention_mask, prompt_attention_mask], axis=0 - ) - if do_spatio_temporal_guidance: - prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) - prompt_attention_mask_batch = jnp.concatenate( - [ - prompt_attention_mask_batch, - prompt_attention_mask, - ], - axis=0, - ) - latents = self.prepare_latents( - latents=latents, - media_items=None, # set to None - timestep=scheduler_state.timesteps[0], - latent_shape=latent_shape, - dtype=None, - device=None, - generator=generator, - vae_per_channel_normalize=vae_per_channel_normalize, - ) - - latents, pixel_coords, conditioning_mask, num_cond_latents = ( - self.prepare_conditioning( - conditioning_items=None, - init_latents=latents, - num_frames=num_frames, - height=height, - width=width, - vae_per_channel_normalize=vae_per_channel_normalize, - generator=generator, - ) - ) - - extra_step_kwargs = prepare_extra_step_kwargs( - generator=jax.random.PRNGKey(0)) - - pixel_coords = torch.cat([pixel_coords] * num_conds) - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) - - noise_cond = jnp.ones( # initialize first round with this! 
- (1, 1) - ) - segment_ids = None #how is this created? - p_run_inference = functools.partial( - run_inference, - transformer=self.transformer, - config=self.config, - mesh=self.mesh, - fractional_cords=jnp.array( - fractional_coords.to(torch.float32).detach().numpy()), - prompt_embeds=prompt_embeds_batch, - segment_ids=None, - encoder_attention_segment_ids=prompt_attention_mask_batch, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - do_classifier_free_guidance=do_classifier_free_guidance, - num_conds=num_conds, - guidance_scale=guidance_scale, - do_spatio_temporal_guidance=do_spatio_temporal_guidance, - stg_scale=stg_scale, - do_rescaling = do_rescaling, - rescaling_scale = rescaling_scale, - batch_size=batch_size, - skip_layer_masks = skip_layer_masks, - cfg_star_rescale = cfg_star_rescale - ) - - with self.mesh: - latents, scheduler_state = p_run_inference(transformer_state=self.transformer_state, latents=jnp.array(latents.to( - torch.float32).detach().numpy()), timestep=noise_cond, scheduler_state=scheduler_state) - latents = torch.from_numpy(np.array(latents)) - latents = latents[:, num_cond_latents:] - - latents = self.patchifier.unpatchify( - latents=latents, - output_height=latent_height, - output_width=latent_width, - out_channels=model_config["in_channels"] - // math.prod(self.patchifier.patch_size), - ) - if output_type != "latent": - if self.vae.decoder.timestep_conditioning: - noise = torch.randn_like(latents) - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * latents.shape[0] - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [ - decode_noise_scale] * latents.shape[0] - - decode_timestep = torch.tensor( - decode_timestep).to(latents.device) - decode_noise_scale = torch.tensor(decode_noise_scale).to( - latents.device - )[:, None, None, None, None] - latents = ( - latents * (1 - decode_noise_scale) + - noise * decode_noise_scale - ) - else: - decode_timestep = None - image = vae_decode( - latents, - self.vae, - is_video, - vae_per_channel_normalize=kwargs.get( - "vae_per_channel_normalize", True), - timestep=decode_timestep, - ) - image = self.image_processor.postprocess( - image, output_type=output_type) - - else: - image = latents - - # Offload all models - - if not return_dict: - return (image,) - - return image - - - -def transformer_forward_pass( # need to jit this? 
wan didnt + if enhance_prompt: + prompt = generate_cinematic_prompt( + self.prompt_enhancer_image_caption_model, + self.prompt_enhancer_image_caption_processor, + self.prompt_enhancer_llm_model, + self.prompt_enhancer_llm_tokenizer, + prompt, + None, # conditioning items set to None + max_new_tokens=text_encoder_max_tokens, + ) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, # device set to none + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_max_tokens=text_encoder_max_tokens, + ) + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if do_spatio_temporal_guidance: + prompt_embeds_batch = jnp.concatenate([prompt_embeds_batch, prompt_embeds], axis=0) + prompt_attention_mask_batch = jnp.concatenate( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + axis=0, + ) + latents = self.prepare_latents( + latents=latents, + media_items=None, # set to None + timestep=scheduler_state.timesteps[0], + latent_shape=latent_shape, + dtype=None, + device=None, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + latents, pixel_coords, conditioning_mask, num_cond_latents = self.prepare_conditioning( + conditioning_items=None, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + + + pixel_coords = torch.cat([pixel_coords] * num_conds) + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + noise_cond = jnp.ones((1, 1)) # initialize first round with this! 
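+    # Bind all step-invariant arguments (transformer, config, mesh, prompt embeddings,
+    # guidance settings, skip-layer masks, ...) into run_inference via functools.partial;
+    # only the transformer state, the latents, the initial timestep and the scheduler
+    # state vary per call, and the denoising loop itself executes under `with self.mesh:` below.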
+ p_run_inference = functools.partial( + run_inference, + transformer=self.transformer, + config=self.config, + mesh=self.mesh, + fractional_cords=jnp.array(fractional_coords.to(torch.float32).detach().numpy()), + prompt_embeds=prompt_embeds_batch, + segment_ids=None, + encoder_attention_segment_ids=prompt_attention_mask_batch, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + do_classifier_free_guidance=do_classifier_free_guidance, + num_conds=num_conds, + guidance_scale=guidance_scale, + do_spatio_temporal_guidance=do_spatio_temporal_guidance, + stg_scale=stg_scale, + do_rescaling=do_rescaling, + rescaling_scale=rescaling_scale, + batch_size=batch_size, + skip_layer_masks=skip_layer_masks, + cfg_star_rescale=cfg_star_rescale, + ) + + with self.mesh: + latents, scheduler_state = p_run_inference( + transformer_state=self.transformer_state, + latents=jnp.array(latents.to(torch.float32).detach().numpy()), + timestep=noise_cond, + scheduler_state=scheduler_state, + ) + latents = torch.from_numpy(np.array(latents)) + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[:, None, None, None, None] + latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), + timestep=decode_timestep, + ) + image = self.image_processor.postprocess(image, output_type=output_type) + + else: + image = latents + + # Offload all models + + if not return_dict: + return (image,) + + return image + + +def transformer_forward_pass( latents, state, noise_cond, @@ -981,220 +785,234 @@ def transformer_forward_pass( # need to jit this? 
wan didnt encoder_attention_segment_ids, skip_layer_mask, ): - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - skip_layer_mask=skip_layer_mask, - ) - return noise_pred, state + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + skip_layer_mask=skip_layer_mask, + ) + return noise_pred, state def run_inference( - transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks,cfg_star_rescale -): - for i, t in enumerate(scheduler_state.timesteps): - current_timestep = t - latent_model_input = ( - jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents - ) - if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): - is_mps = False - if isinstance(current_timestep, float): - dtype = jnp.float32 - else: - dtype = jnp.int32 - - current_timestep = jnp.array( - [current_timestep], - dtype=dtype, - ) - elif current_timestep.ndim == 0: - current_timestep = jnp.expand_dims(current_timestep, axis=0) - - # Broadcast to batch dimension - current_timestep = jnp.broadcast_to( - current_timestep, (latent_model_input.shape[0],1) - ) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): #error out with this line - noise_pred, transformer_state = transformer_forward_pass( - latent_model_input, transformer_state, current_timestep, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids, skip_layer_mask=( - skip_layer_masks[i] - if skip_layer_masks is not None - else None - )) - # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128)) - - - if do_spatio_temporal_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_text = chunks[-2] - noise_pred_text_perturb = chunks[-1] - - if do_classifier_free_guidance: - chunks = jnp.split(noise_pred, num_conds, axis=0) - noise_pred_uncond = chunks[0] - noise_pred_text = chunks[1] - if cfg_star_rescale: - positive_flat = noise_pred_text.reshape(batch_size, -1) - negative_flat = noise_pred_uncond.reshape(batch_size, -1) - dot_product = jnp.sum( - positive_flat * negative_flat, axis=1, keepdims=True - ) - squared_norm = ( - jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 - ) - alpha = dot_product / squared_norm - alpha = alpha.reshape(batch_size, 1, 1) - - noise_pred_uncond = alpha * noise_pred_uncond - noise_pred = noise_pred_uncond + guidance_scale[i] * ( - noise_pred_text - noise_pred_uncond - ) - elif 
do_spatio_temporal_guidance: - noise_pred = noise_pred_text - - if do_spatio_temporal_guidance: - noise_pred = noise_pred + stg_scale[i] * ( - noise_pred_text - noise_pred_text_perturb - ) - if do_rescaling and stg_scale[i] > 0.0: - noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) - noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) - - factor = noise_pred_text_std / noise_pred_std - factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) - - - noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) - current_timestep = current_timestep[:1] - latents, scheduler_state = scheduler.step( - scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() - - return latents, scheduler_state - -def adain_filter_latent( - latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 + transformer_state, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + num_inference_steps, + scheduler, + segment_ids, + encoder_attention_segment_ids, + scheduler_state, + do_classifier_free_guidance, + num_conds, + guidance_scale, + do_spatio_temporal_guidance, + stg_scale, + do_rescaling, + rescaling_scale, + batch_size, + skip_layer_masks, + cfg_star_rescale, ): - """ - Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on - statistics from a reference latent tensor. - - Args: - latent (torch.Tensor): Input latents to normalize - reference_latent (torch.Tensor): The reference latents providing style statistics. - factor (float): Blending factor between original and transformed latent. - Range: -10.0 to 10.0, Default: 1.0 - - Returns: - torch.Tensor: The transformed latent tensor - """ - result = latents.clone() - - for i in range(latents.size(0)): - for c in range(latents.size(1)): - r_sd, r_mean = torch.std_mean( - reference_latents[i, c], dim=None - ) # index by original dim order - i_sd, i_mean = torch.std_mean(result[i, c], dim=None) - - result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean - - result = torch.lerp(latents, result, factor) - return result + for i, t in enumerate(scheduler_state.timesteps): + current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): + if isinstance(current_timestep, float): + dtype = jnp.float32 + else: + dtype = jnp.int32 + + current_timestep = jnp.array( + [current_timestep], + dtype=dtype, + ) + elif current_timestep.ndim == 0: + current_timestep = jnp.expand_dims(current_timestep, axis=0) + + # Broadcast to batch dimension + current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + + noise_pred, transformer_state = transformer_forward_pass( + latent_model_input, + transformer_state, + current_timestep, + transformer, + fractional_cords, + prompt_embeds, + segment_ids, + encoder_attention_segment_ids, + skip_layer_mask=(skip_layer_masks[i] if skip_layer_masks is not None else None), + ) + + if do_spatio_temporal_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_text = chunks[-2] + noise_pred_text_perturb = chunks[-1] + + if do_classifier_free_guidance: + chunks = jnp.split(noise_pred, num_conds, axis=0) + noise_pred_uncond = chunks[0] + noise_pred_text = chunks[1] + if cfg_star_rescale: + positive_flat = noise_pred_text.reshape(batch_size, -1) + negative_flat = noise_pred_uncond.reshape(batch_size, -1) + dot_product = 
jnp.sum(positive_flat * negative_flat, axis=1, keepdims=True) + squared_norm = jnp.sum(negative_flat**2, axis=1, keepdims=True) + 1e-8 + alpha = dot_product / squared_norm + alpha = alpha.reshape(batch_size, 1, 1) + + noise_pred_uncond = alpha * noise_pred_uncond + noise_pred = noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * (noise_pred_text - noise_pred_text_perturb) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = jnp.std(noise_pred_text.reshape(batch_size, -1), axis=1, keepdims=True) + noise_pred_std = jnp.std(noise_pred.reshape(batch_size, -1), axis=1, keepdims=True) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) + current_timestep = current_timestep[:1] + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + return latents, scheduler_state + + +def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + class LTXMultiScalePipeline: - - @classmethod - def load_latent_upsampler(cls, config): - spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path - - if spatial_upscaler_model_name_or_path and not os.path.isfile( - spatial_upscaler_model_name_or_path - ): - spatial_upscaler_model_path = hf_hub_download( - repo_id="Lightricks/LTX-Video", - filename=spatial_upscaler_model_name_or_path, - local_dir= "/mnt/disks/diffusionproj", - repo_type="model", - ) - else: - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path - if not config.spatial_upscaler_model_path: - raise ValueError( - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" - ) - latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) - latent_upsampler.eval() - return latent_upsampler - - - def _upsample_latents( - self, latest_upsampler: LatentUpsampler, latents: torch.Tensor - ): - assert latents.device == latest_upsampler.device - - latents = un_normalize_latents( - latents, self.vae, vae_per_channel_normalize=True - ) - upsampled_latents = latest_upsampler(latents) - upsampled_latents = normalize_latents( - upsampled_latents, self.vae, vae_per_channel_normalize=True - ) - return upsampled_latents - - def __init__( - self, video_pipeline: LTXVideoPipeline - ): - self.video_pipeline = video_pipeline - self.vae = video_pipeline.vae - - def __call__( - self, - height, - width, - 
num_frames, - output_type, - generator, - config - ) -> Any: - - latent_upsampler = self.load_latent_upsampler(config) - original_output_type = output_type - output_type = 'latent' - result = self.video_pipeline(height=height, width=width, num_frames=num_frames, - is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"], - num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_block_list=config.first_pass["skip_block_list"]) - latents = result - upsampled_latents = self._upsample_latents(latent_upsampler, latents) - upsampled_latents = adain_filter_latent( - latents=upsampled_latents, reference_latents=latents - ) - - latents = upsampled_latents - output_type = original_output_type - - - result = self.video_pipeline(height=height*2, width=width*2, num_frames=num_frames, - is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"], - num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_block_list=config.second_pass["skip_block_list"]) - - if original_output_type != "latent": - num_frames = result.shape[2] - videos = rearrange(result, "b c f h w -> (b f) c h w") - - videos = F.interpolate( - videos, - size=(height, width), - mode="bilinear", - align_corners=False, - ) - videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) - result = videos - - return result \ No newline at end of file + + @classmethod + def load_latent_upsampler(cls, config): + spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path + + if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=config.models_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + if not config.spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path) + latent_upsampler.eval() + return latent_upsampler + + def _upsample_latents(self, latest_upsampler: LatentUpsampler, latents: torch.Tensor): + assert latents.device == latest_upsampler.device + + latents = un_normalize_latents(latents, self.vae, vae_per_channel_normalize=True) + upsampled_latents = latest_upsampler(latents) + upsampled_latents = normalize_latents(upsampled_latents, self.vae, vae_per_channel_normalize=True) + return upsampled_latents + + def __init__(self, video_pipeline: LTXVideoPipeline): + self.video_pipeline = 
video_pipeline + self.vae = video_pipeline.vae + + def __call__(self, height, width, num_frames, output_type, generator, config) -> Any: + + latent_upsampler = self.load_latent_upsampler(config) + original_output_type = output_type + output_type = "latent" + result = self.video_pipeline( + height=height, + width=width, + num_frames=num_frames, + is_video=True, + output_type=output_type, + generator=generator, + guidance_scale=config.first_pass["guidance_scale"], + stg_scale=config.first_pass["stg_scale"], + rescaling_scale=config.first_pass["rescaling_scale"], + skip_initial_inference_steps=config.first_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.first_pass["skip_final_inference_steps"], + num_inference_steps=config.first_pass["num_inference_steps"], + guidance_timesteps=config.first_pass["guidance_timesteps"], + cfg_star_rescale=config.first_pass["cfg_star_rescale"], + skip_block_list=config.first_pass["skip_block_list"], + ) + latents = result + upsampled_latents = self._upsample_latents(latent_upsampler, latents) + upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents) + + latents = upsampled_latents + output_type = original_output_type + + result = self.video_pipeline( + height=height * 2, + width=width * 2, + num_frames=num_frames, + is_video=True, + output_type=output_type, + latents=latents, + generator=generator, + guidance_scale=config.second_pass["guidance_scale"], + stg_scale=config.second_pass["stg_scale"], + rescaling_scale=config.second_pass["rescaling_scale"], + skip_initial_inference_steps=config.second_pass["skip_initial_inference_steps"], + skip_final_inference_steps=config.second_pass["skip_final_inference_steps"], + num_inference_steps=config.second_pass["num_inference_steps"], + guidance_timesteps=config.second_pass["guidance_timesteps"], + cfg_star_rescale=config.second_pass["cfg_star_rescale"], + skip_block_list=config.second_pass["skip_block_list"], + ) + + if original_output_type != "latent": + num_frames = result.shape[2] + videos = rearrange(result, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result = videos + + return result diff --git a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py index 1624b81c4..b550aeea3 100644 --- a/src/maxdiffusion/schedulers/scheduling_rectified_flow.py +++ b/src/maxdiffusion/schedulers/scheduling_rectified_flow.py @@ -24,42 +24,47 @@ import jax.numpy as jnp import json from maxdiffusion.configuration_utils import ConfigMixin, register_to_config -from maxdiffusion.utils import is_scipy_available from maxdiffusion.schedulers.scheduling_utils_flax import ( CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, ) -def linear_quadratic_schedule_jax(num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None) -> jnp.ndarray: - if num_steps == 1: - return jnp.array([1.0], dtype=jnp.float32) - if linear_steps is None: - linear_steps = num_steps // 2 - linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps +def linear_quadratic_schedule_jax( + num_steps: int, threshold_noise: float = 0.025, linear_steps: Optional[int] = None +) -> jnp.ndarray: + if num_steps == 1: + return jnp.array([1.0], dtype=jnp.float32) + if linear_steps is None: + linear_steps = num_steps // 2 + + 
linear_sigma_schedule = jnp.arange(linear_steps) * threshold_noise / linear_steps + + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_indices = jnp.arange(linear_steps, num_steps) + quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const - threshold_noise_step_diff = linear_steps - threshold_noise * num_steps - quadratic_steps = num_steps - linear_steps - quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) - linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) - const = quadratic_coef * (linear_steps**2) - quadratic_indices = jnp.arange(linear_steps, num_steps) - quadratic_sigma_schedule = quadratic_coef * (quadratic_indices**2) + linear_coef * quadratic_indices + const + sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) + sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) + sigma_schedule = 1.0 - sigma_schedule + return sigma_schedule[:-1].astype(jnp.float32) - sigma_schedule = jnp.concatenate([linear_sigma_schedule, quadratic_sigma_schedule]) - sigma_schedule = jnp.concatenate([sigma_schedule, jnp.array([1.0])]) - sigma_schedule = 1.0 - sigma_schedule - return sigma_schedule[:-1].astype(jnp.float32) def time_shift_jax(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray: - mu_f = jnp.array(mu, dtype=jnp.float32) - sigma_f = jnp.array(sigma, dtype=jnp.float32) - return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) - + mu_f = jnp.array(mu, dtype=jnp.float32) + sigma_f = jnp.array(sigma, dtype=jnp.float32) + return jnp.exp(mu_f) / (jnp.exp(mu_f) + (1 / t - 1) ** sigma_f) + + def _prod_jax(iterable): - return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) - + return jnp.prod(jnp.array(iterable, dtype=jnp.float32)) + + def get_normal_shift_jax( n_tokens: int, min_tokens: int = 1024, @@ -67,72 +72,71 @@ def get_normal_shift_jax( min_shift: float = 0.95, max_shift: float = 2.05, ) -> float: - m = (max_shift - min_shift) / (max_tokens - min_tokens) - b = min_shift - m * min_tokens - return m * n_tokens + b -def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray: - """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" - return x[(...,) + (None,) * (target_dims - x.ndim)] + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b +def append_dims_jax(x: jnp.ndarray, target_dims: int) -> jnp.ndarray: + """Appends singleton dimensions to the end of a tensor until it reaches `target_dims`.""" + return x[(...,) + (None,) * (target_dims - x.ndim)] + def strech_shifts_to_terminal_jax(shifts: jnp.ndarray, terminal: float = 0.1) -> jnp.ndarray: - if shifts.size == 0: - raise ValueError("The 'shifts' tensor must not be empty.") - if terminal <= 0 or terminal >= 1: - raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + if shifts.size == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + one_minus_z = 1.0 - shifts + # Using shifts[-1] for the last 
element + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched_shifts = 1.0 - (one_minus_z / scale_factor) - one_minus_z = 1.0 - shifts - # Using shifts[-1] for the last element - scale_factor = one_minus_z[-1] / (1.0 - terminal) - stretched_shifts = 1.0 - (one_minus_z / scale_factor) + return stretched_shifts - return stretched_shifts def sd3_resolution_dependent_timestep_shift_jax( samples_shape: Tuple[int, ...], timesteps: jnp.ndarray, target_shift_terminal: Optional[float] = None, ) -> jnp.ndarray: - if len(samples_shape) == 3: - _, m, _ = samples_shape - elif len(samples_shape) in [4, 5]: - m = _prod_jax(samples_shape[2:]) - else: - raise ValueError( - "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" - ) + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + + shift = get_normal_shift_jax(int(m)) + time_shifts = time_shift_jax(shift, 1.0, timesteps) - shift = get_normal_shift_jax(int(m)) - time_shifts = time_shift_jax(shift, 1.0, timesteps) + if target_shift_terminal is not None: + time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) + return time_shifts - if target_shift_terminal is not None: - time_shifts = strech_shifts_to_terminal_jax(time_shifts, target_shift_terminal) - return time_shifts - def simple_diffusion_resolution_dependent_timestep_shift_jax( samples_shape: Tuple[int, ...], timesteps: jnp.ndarray, n: int = 32 * 32, ) -> jnp.ndarray: - if len(samples_shape) == 3: - _, m, _ = samples_shape - elif len(samples_shape) in [4, 5]: - m = _prod_jax(samples_shape[2:]) - else: - raise ValueError( - "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" - ) - # Ensure m and n are float32 for calculations - m_f = jnp.array(m, dtype=jnp.float32) - n_f = jnp.array(n, dtype=jnp.float32) + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = _prod_jax(samples_shape[2:]) + else: + raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)") + # Ensure m and n are float32 for calculations + m_f = jnp.array(m, dtype=jnp.float32) + n_f = jnp.array(n, dtype=jnp.float32) + + snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability + shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability + shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) - snr = (timesteps / (1 - timesteps)) ** 2 # Add epsilon for numerical stability - shift_snr = jnp.log(snr) + 2 * jnp.log(m_f / n_f) # Add epsilon for numerical stability - shifted_timesteps = jax.nn.sigmoid(0.5 * shift_snr) + return shifted_timesteps - return shifted_timesteps @flax.struct.dataclass class RectifiedFlowSchedulerState: @@ -145,28 +149,21 @@ class RectifiedFlowSchedulerState: num_inference_steps: Optional[int] = None timesteps: Optional[jnp.ndarray] = None sigmas: Optional[jnp.ndarray] = None - - - @classmethod - def create( #need to change this! 
- cls, - common_state: CommonSchedulerState, - init_noise_sigma: float - ): + def create(cls, common_state: CommonSchedulerState, init_noise_sigma: float): return cls( - common = common_state, - init_noise_sigma = init_noise_sigma, - num_inference_steps = None, - timesteps = None, - sigmas = None, + common=common_state, + init_noise_sigma=init_noise_sigma, + num_inference_steps=None, + timesteps=None, + sigmas=None, ) @dataclass class FlaxRectifiedFlowSchedulerOutput(FlaxSchedulerOutput): - state: RectifiedFlowSchedulerState + state: RectifiedFlowSchedulerState class FlaxRectifiedFlowMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): @@ -195,163 +192,136 @@ def __init__( dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype - def create_state(self, common: Optional[CommonSchedulerState] = None) -> RectifiedFlowSchedulerState: if common is None: common = CommonSchedulerState.create(self) init_noise_sigma = 1.0 - return RectifiedFlowSchedulerState.create(common_state = common, init_noise_sigma=init_noise_sigma) - - - def get_initial_timesteps_jax( - self, num_timesteps: int, shift: Optional[float] = None - ) -> jnp.ndarray: - if self.config.sampler == "Uniform": - return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) - elif self.config.sampler == "LinearQuadratic": - return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) - elif self.config.sampler == "Constant": - assert shift is not None, "Shift must be provided for constant time shift sampler." - return time_shift_jax( - shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) - ).astype(self.dtype) - else: - # This should be caught by __init__ but for safety - raise ValueError(f"Sampler {self.config.sampler} is not supported.") + return RectifiedFlowSchedulerState.create(common_state=common, init_noise_sigma=init_noise_sigma) + + def get_initial_timesteps_jax(self, num_timesteps: int, shift: Optional[float] = None) -> jnp.ndarray: + if self.config.sampler == "Uniform": + return jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype) + elif self.config.sampler == "LinearQuadratic": + return linear_quadratic_schedule_jax(num_timesteps).astype(self.dtype) + elif self.config.sampler == "Constant": + assert shift is not None, "Shift must be provided for constant time shift sampler." 
+ return time_shift_jax(shift, 1.0, jnp.linspace(1.0, 1.0 / num_timesteps, num_timesteps, dtype=self.dtype)).astype( + self.dtype + ) + else: + raise ValueError(f"Sampler {self.config.sampler} is not supported.") def shift_timesteps_jax(self, samples_shape: Tuple[int, ...], timesteps: jnp.ndarray) -> jnp.ndarray: - if self.config.shifting == "SD3": - return sd3_resolution_dependent_timestep_shift_jax( - samples_shape, timesteps, self.config.target_shift_terminal - ) - elif self.config.shifting == "SimpleDiffusion": - return simple_diffusion_resolution_dependent_timestep_shift_jax( - samples_shape, timesteps, self.config.base_resolution - ) - return timesteps - + if self.config.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.target_shift_terminal) + elif self.config.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift_jax(samples_shape, timesteps, self.config.base_resolution) + return timesteps + def from_pretrained_jax(pretrained_model_path: Union[str, os.PathLike]): pretrained_model_path = Path(pretrained_model_path) config = None if pretrained_model_path.is_file(): - with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - configs = json.loads(metadata['config']) - config = configs["scheduler"] - + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + elif pretrained_model_path.is_dir(): - diffusers_noise_scheduler_config_path = ( - pretrained_model_path / "scheduler" / "scheduler_config.json" - ) - - if not diffusers_noise_scheduler_config_path.is_file(): - raise FileNotFoundError( - f"Scheduler config not found at {diffusers_noise_scheduler_config_path}" - ) - - with open(diffusers_noise_scheduler_config_path, "r") as f: - scheduler_config = json.load(f) - config = scheduler_config + diffusers_noise_scheduler_config_path = pretrained_model_path / "scheduler" / "scheduler_config.json" + + if not diffusers_noise_scheduler_config_path.is_file(): + raise FileNotFoundError(f"Scheduler config not found at {diffusers_noise_scheduler_config_path}") + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + config = scheduler_config return FlaxRectifiedFlowMultistepScheduler.from_config(config) - + def set_timesteps( - self, - state: RectifiedFlowSchedulerState, - num_inference_steps: Optional[int] = None, - samples_shape: Optional[Tuple[int, ...]] = None, - timesteps: Optional[jnp.ndarray] = None, - device: Optional[str] = None, - ) -> RectifiedFlowSchedulerState: - if timesteps is not None and num_inference_steps is not None: - raise ValueError( - "You cannot provide both `timesteps` and `num_inference_steps`." 
- ) - - # Determine the number of inference steps if not provided - if num_inference_steps is None and timesteps is None: - raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") - - if timesteps is None: - num_inference_steps = jnp.minimum( - self.config.num_train_timesteps, num_inference_steps - ) - timesteps = self.get_initial_timesteps_jax( - num_inference_steps, shift=self.config.shift - ).astype(self.dtype) - - # Apply shifting if samples_shape is provided and shifting is configured - if samples_shape is not None: - timesteps = self.shift_timesteps_jax(samples_shape, timesteps) - else: - timesteps = jnp.asarray(timesteps, dtype=self.dtype) - num_inference_steps = len(timesteps) - - return state.replace( - timesteps=timesteps, - num_inference_steps=num_inference_steps, - sigmas=timesteps, # sigmas are the same as timesteps in RF - ) - + self, + state: RectifiedFlowSchedulerState, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[Tuple[int, ...]] = None, + timesteps: Optional[jnp.ndarray] = None, + device: Optional[str] = None, + ) -> RectifiedFlowSchedulerState: + if timesteps is not None and num_inference_steps is not None: + raise ValueError("You cannot provide both `timesteps` and `num_inference_steps`.") + + # Determine the number of inference steps if not provided + if num_inference_steps is None and timesteps is None: + raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.") + + if timesteps is None: + num_inference_steps = jnp.minimum(self.config.num_train_timesteps, num_inference_steps) + timesteps = self.get_initial_timesteps_jax(num_inference_steps, shift=self.config.shift).astype(self.dtype) + + # Apply shifting if samples_shape is provided and shifting is configured + if samples_shape is not None: + timesteps = self.shift_timesteps_jax(samples_shape, timesteps) + else: + timesteps = jnp.asarray(timesteps, dtype=self.dtype) + num_inference_steps = len(timesteps) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + sigmas=timesteps, # sigmas are the same as timesteps in RF + ) + def scale_model_input( - self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - # Rectified Flow scheduler typically doesn't scale model input, returns as is. - return sample - - + self, state: RectifiedFlowSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + # Rectified Flow scheduler typically doesn't scale model input, returns as is. + return sample + def step( - self, - state: RectifiedFlowSchedulerState, - model_output: jnp.ndarray, - timestep: jnp.ndarray, # Can be global or per-token, but for RF it's typically global. 
- sample: jnp.ndarray, - return_dict: bool = True, - stochastic_sampling: bool = False, - generator: Optional[jax.random.PRNGKey] = None, - ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - t_eps = 1e-6 # Small epsilon for numerical issues - - timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) - - - if timestep.ndim == 0: - idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side='right') - current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] - lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), - timesteps_padded[current_t_idx + 1], - 0.0) - dt = timestep - lower_timestep - else: - current_t_indices = jnp.searchsorted(state.timesteps, timestep, side='right') # timesteps is decreasing - current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search - lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) - lower_timestep = timesteps_padded[lower_timestep_indices] - dt = timestep - lower_timestep - dt = append_dims_jax(dt, sample.ndim) - - - # Compute previous sample - if stochastic_sampling: - if generator is None: - raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") - broadcastable_timestep = append_dims_jax(timestep, sample.ndim) - - x0 = sample - broadcastable_timestep * model_output - next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep - - noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) - prev_sample = self.add_noise(state.common, x0, noise, next_timestep) - else: - prev_sample = sample - dt * model_output - - - if not return_dict: - return (prev_sample, state) - - return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) + self, + state: RectifiedFlowSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, + sample: jnp.ndarray, + return_dict: bool = True, + stochastic_sampling: bool = False, + generator: Optional[jax.random.PRNGKey] = None, + ) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + if state.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + t_eps = 1e-6 # Small epsilon for numerical issues + + timesteps_padded = jnp.concatenate([state.timesteps, jnp.array([0.0], dtype=self.dtype)]) + + if timestep.ndim == 0: + idx = jnp.searchsorted(timesteps_padded, timestep - t_eps, side="right") #noqa: F841 + current_t_idx = jnp.where(state.timesteps == timestep, size=1, fill_value=len(state.timesteps))[0][0] + lower_timestep = jnp.where(current_t_idx + 1 < len(timesteps_padded), timesteps_padded[current_t_idx + 1], 0.0) + dt = timestep - lower_timestep + else: + current_t_indices = jnp.searchsorted(state.timesteps, timestep, side="right") # timesteps is decreasing + current_t_indices = jnp.where(current_t_indices > 0, current_t_indices - 1, 0) # adjust for right side search + lower_timestep_indices = jnp.minimum(current_t_indices + 1, len(timesteps_padded) - 1) + lower_timestep = timesteps_padded[lower_timestep_indices] + dt = timestep - lower_timestep + dt = append_dims_jax(dt, sample.ndim) + 
+ # Compute previous sample + if stochastic_sampling: + if generator is None: + raise ValueError("`generator` PRNGKey must be provided for stochastic sampling.") + broadcastable_timestep = append_dims_jax(timestep, sample.ndim) + + x0 = sample - broadcastable_timestep * model_output + next_timestep = timestep - dt.squeeze((1,) * (dt.ndim - timestep.ndim)) # Remove extra dims from dt to match timestep + + noise = jax.random.normal(generator, sample.shape, dtype=self.dtype) + prev_sample = self.add_noise(state.common, x0, noise, next_timestep) + else: + prev_sample = sample - dt * model_output + + if not return_dict: + return (prev_sample, state) + + return FlaxRectifiedFlowSchedulerOutput(prev_sample=prev_sample, state=state) From f63a6fab9ce6c25ef9dd46ed441f02bb62608de7 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 18:59:14 +0000 Subject: [PATCH 33/34] remove init --- src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py deleted file mode 100644 index e69de29bb..000000000 From b3874f565884f111af819a6ccd1747e11239483f Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 16 Jul 2025 19:03:31 +0000 Subject: [PATCH 34/34] new empty folders --- .../ltx_video/models/autoencoders/__init__.py | 16 ++++++++++++++++ .../models/ltx_video/utils/__init__.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/utils/__init__.py diff --git a/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/models/autoencoders/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/__init__.py b/src/maxdiffusion/models/ltx_video/utils/__init__.py new file mode 100644 index 000000000..cb4a6b9ce --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main \ No newline at end of file
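
For orientation, below is a minimal usage sketch of the rectified-flow scheduler introduced in this patch series. The method names and signatures (from_pretrained_jax, create_state, set_timesteps, step and its return_dict=False tuple) follow the diff above; the checkpoint path, the latent shape, and the zero-valued stand-in for the transformer's prediction are illustrative placeholders, not part of the patch.

import jax
import jax.numpy as jnp

from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler

# Load the scheduler config from an LTX-Video checkpoint (placeholder path, hypothetical filename).
scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax("/path/to/ltx-video/checkpoint.safetensors")
state = scheduler.create_state()

# Latents are a flattened token sequence: (batch, tokens, channels); the values here are dummies.
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 256, 128), dtype=jnp.float32)
state = scheduler.set_timesteps(state, num_inference_steps=8, samples_shape=latents.shape)

for t in state.timesteps:
    # A real pipeline would call the LTX transformer here; zeros stand in for its prediction.
    model_output = jnp.zeros_like(latents)
    latents, state = scheduler.step(state, model_output, t, latents, return_dict=False)

The non-stochastic branch of step reduces to an Euler update on the predicted velocity field, sample - dt * model_output, with dt taken as the gap to the next timestep in the schedule.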