From 377619088bcd5b04fcab5413b3dfd740c32eb27a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 19:05:46 +0000 Subject: [PATCH 01/25] set up files for ltxvid --- src/maxdiffusion/__init__.py | 721 +++++++++++++------------ src/maxdiffusion/configs/ltx_video.yml | 50 ++ src/maxdiffusion/generate_ltx_video.py | 73 +++ src/maxdiffusion/models/__init__.py | 5 +- 4 files changed, 490 insertions(+), 359 deletions(-) create mode 100644 src/maxdiffusion/configs/ltx_video.yml create mode 100644 src/maxdiffusion/generate_ltx_video.py diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed682..677d64e4e 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,438 +65,447 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [ + name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [ + name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - 
"KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - 
"StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + 
"WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend( + ["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = [ + "FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = [ + "FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = [ + "Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - 
"FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 
+ else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, + ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, - ) + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 000000000..ac333d329 --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,50 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 000000000..81c832c3c --- /dev/null +++ 
b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,73 @@ +from absl import app +from typing import Sequence +import jax +import json +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import functools +import jax.numpy as jnp +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, +) +from jax.sharding import Mesh, PartitionSpec as P + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", + fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + + +def run(config): + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) + + # load in model config + config_path = os.path.join( + base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + + transformer = Transformer3DModel( + **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights( + key, batch_size, text_tokens, num_tokens, features, eval_only=False) + + key, split_key = jax.random.split(key) + weights_init_fn = functools.partial( + transformer.init_weights, + split_key, + batch_size, + text_tokens, + num_tokens, + features, + eval_only=False + ) + + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24e..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -13,9 +13,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING - -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available - +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,6 +30,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys From 13656fb457723f52e935e8afab69a983d5cfd68a Mon Sep 17 00:00:00 2001 From: Serenagu525 Date: Thu, 26 Jun 2025 20:32:05 +0000 Subject: [PATCH 02/25] ltx-video-transformer-setup --- src/maxdiffusion/configs/ltx_video.yml | 15 + src/maxdiffusion/generate_ltx_video.py | 29 +- src/maxdiffusion/models/__init__.py | 17 +- src/maxdiffusion/models/ltx_video/__init__.py | 0 .../models/ltx_video/gradient_checkpoint.py | 70 ++ src/maxdiffusion/models/ltx_video/linear.py | 111 ++ .../models/ltx_video/repeatable_layer.py | 105 ++ .../models/ltx_video/transformers/__init__.py | 0 .../ltx_video/transformers/activations.py | 176 ++++ .../models/ltx_video/transformers/adaln.py | 201 ++++ .../ltx_video/transformers/attention.py | 945 ++++++++++++++++++ .../transformers/caption_projection.py | 40 + .../ltx_video/transformers/transformer3d.py | 322 ++++++ .../ltx_video/xora_v1.2-13B-balanced-128.json | 24 + 14 files changed, 2036 insertions(+), 19 deletions(-) create mode 100644 src/maxdiffusion/models/ltx_video/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/gradient_checkpoint.py create mode 100644 src/maxdiffusion/models/ltx_video/linear.py create mode 100644 src/maxdiffusion/models/ltx_video/repeatable_layer.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/__init__.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/activations.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/adaln.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/attention.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/caption_projection.py create mode 100644 src/maxdiffusion/models/ltx_video/transformers/transformer3d.py create mode 100644 src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index ac333d329..954922521 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,3 +1,18 @@ +# Copyright 2025 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + #hardware hardware: 'tpu' skip_jax_distributed_system: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 81c832c3c..6d96aa8c2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,20 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + + from absl import app from typing import Sequence import jax @@ -50,17 +67,7 @@ def run(config): text_tokens, num_tokens, features, - eval_only=False - ) - - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, + eval_only=True ) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 96a6f1286..20c27ab20 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,14 +25,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 000000000..f32cc9459 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,70 @@ +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = 
auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 000000000..fd92c695d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,111 @@ +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, + jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] 
= () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features):], + kernel_shape[-len(features):], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 000000000..882f21ace --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,105 @@ +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles + + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction 
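+    (Illustrative, assumed usage: ``RepeatableLayer(module=SomeBlock, num_layers=N,
+    module_init_kwargs={...})`` builds N identically-structured blocks whose parameters
+    are stacked along ``param_scan_axis`` and executed with a single ``nn.scan`` compilation.)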
+ """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = { + nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( + self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn( + self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 000000000..3e1fd6d6e --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,176 @@ +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from diffusers.utils.deprecation_utils import deprecate + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + # Mish is not in JAX by default + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor + + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError( + f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 000000000..374af6acc --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,201 @@ +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
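+
+    Example (illustrative sketch only; the (2, 5) timestep grid is an assumed shape, not taken from the model code):
+        >>> t = jnp.ones((2, 5))
+        >>> emb = get_timestep_embedding_multidim(t, embedding_dim=8)
+        >>> emb.shape  # -> (2, 5, 8), i.e. (*t.shape, embedding_dim)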
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint( + emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate( + [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint( + sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + + """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb( + timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 000000000..4ade671c7 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,945 @@ +from functools import partial +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto + +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", 
"rms_norm"] + assert self.adaptive_norm in [ + "single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name( + hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", + "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * \ + (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm( + hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError( + f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", + "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any( + device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print( + "Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get( + "scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", + "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint( + hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint( + encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape( + hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states( + encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key 
= self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", + "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones( + (batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", + "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes( + hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + \ + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1]( + hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes( + hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape( + skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad( + attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + 
attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), + kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError( + f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
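+ 
+     In short, most block edges below are the smaller of a dtype-dependent cap (1024 for
+     bfloat16 inputs, 512 otherwise) and the matching sequence length, with the batch block
+     fixed at 1. As an illustration (our numbers, not taken from the original comments): a
+     bfloat16 query of length 4096 attending to a key/value sequence of length 256 gives
+     block_q = min(1024, 4096) = 1024 and block_k = block_k_major = min(1024, 256) = 256.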
+ 
+     TPU kernel ops run over a grid of blocks: the bigger the block, the more data is loaded
+     into SRAM at once, and we want to utilize the SRAM as well as we can.
+ 
+     Blocks that are too large cause cache misses and slow the computation down while the
+     faster SRAM waits for the next block of data to arrive from the slower HBM.
+ 
+     A balance has to be struck to get the best performance. Ideally, that balance would be
+     computed from the information supplied by q and k (which give the query and key/value
+     sequence lengths) together with the SRAM cache size.
+ 
+     ** SRAM cache size for TPU
+     V5P - 1MB SRAM per core
+ 
+     Args:
+       q (jax.Array): Query tensor to be used
+       k (jax.Array): Key tensor to be used
+ 
+     Returns:
+       BlockSizes: Grid block sizes
+     """
+     max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+     return BlockSizes(
+         block_q=min(max_block_size, q.shape[-2]),
+         block_k_major=min(max_block_size, k.shape[-2]),
+         block_k=min(max_block_size, k.shape[-2]),
+         block_b=min(1, q.shape[0]),
+         block_q_major_dkv=min(max_block_size, q.shape[-2]),
+         block_k_major_dkv=min(max_block_size, k.shape[-2]),
+         block_q_dkv=min(max_block_size, q.shape[-2]),
+         block_k_dkv=min(max_block_size, k.shape[-2]),
+         block_q_dq=min(max_block_size, q.shape[-2]),
+         block_k_dq=min(512, k.shape[-2]),
+         block_k_major_dq=min(max_block_size, k.shape[-2]),
+     )
+ 
+ 
+ class ExplicitAttention(nn.Module):
+   def __call__(
+       self,
+       q: jax.Array,
+       k: jax.Array,
+       v: jax.Array,
+       q_segment_ids: jax.Array,
+       kv_segment_ids: jax.Array,
+       sharding_mesh: Optional[jax.sharding.Mesh] = None,
+       dtype: jnp.dtype = jnp.float32,
+   ):
+     assert sharding_mesh is None, "Explicit attention does not support sharding mesh."
+     attn_mask = None
+     if kv_segment_ids is not None:
+       q_segment_ids_expanded = q_segment_ids[:, None, :, None]
+       kv_segment_ids_expanded = kv_segment_ids[:, None, None, :]
+       attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded
+ 
+     scale_factor = 1 / jnp.sqrt(q.shape[-1])
+     attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype)
+ 
+     if attn_mask is not None:
+       if attn_mask.dtype == jnp.bool_:
+         attn_bias = jnp.where(attn_mask, attn_bias, float("-inf"))
+       else:
+         attn_bias += attn_mask
+ 
+     attn_weight = q @ k.swapaxes(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = jnn.softmax(attn_weight, axis=-1)
+ 
+     return attn_weight @ v
+ 
+ 
+ class RMSNorm(nn.Module):
+   """
+   RMSNorm is a normalization layer that normalizes the input using the root mean square.
+   """
+ 
+   epsilon: float
+   dtype: jnp.dtype = jnp.float32
+   elementwise_affine: bool = True
+   weight_dtype: jnp.dtype = jnp.float32
+   kernel_axes: Tuple[Optional[str], ...] = ()
+   scale_init: Initializer = nn.initializers.ones
+ 
+   @nn.compact
+   def __call__(self, hidden_states: jax.Array) -> jax.Array:
+     """
+     Forward pass of the RMSNorm layer.
+ 
+     First we compute the variance (the mean of the square of the input)
+     and then normalize the input by the root mean square.
+ 
+     NOTE: if the weight is in mixed precision, the operand should be in the same precision.
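+ 
+     Written out (this mirrors the computation below; the notation is ours, added for clarity):
+ 
+         variance = mean(x.astype(float32) ** 2, axis=-1, keepdims=True)
+         y = x * rsqrt(variance + epsilon)
+         y = y * scale        # only when `elementwise_affine` is True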
+ Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning( + self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype( + jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * \ + jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) + * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError( + f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if 
self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)( + hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. + + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 000000000..dff8b8c62 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,40 @@ +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 000000000..4368c35fb --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,322 @@ +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 
0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning( + scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm( + epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + 
norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal( + split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal( + split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal( + split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal( + split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection( + encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, + :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + 
scale_shift_values, ("activation_batch", "activation_length", + "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. + dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, + axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 000000000..02f13b15a --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,24 @@ +{ + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000 +} \ No newline at end of file From 7bed4f99d1b8c19eb1ee622ee895f1ca31b1b870 Mon Sep 17 00:00:00 2001 
From: "serenagu@google.com" Date: Thu, 26 Jun 2025 21:12:17 +0000 Subject: [PATCH 03/25] formatting --- src/maxdiffusion/__init__.py | 723 ++++--- src/maxdiffusion/generate_ltx_video.py | 60 +- src/maxdiffusion/models/__init__.py | 17 +- .../models/ltx_video/gradient_checkpoint.py | 102 +- src/maxdiffusion/models/ltx_video/linear.py | 170 +- .../models/ltx_video/repeatable_layer.py | 129 +- .../ltx_video/transformers/activations.py | 275 ++- .../models/ltx_video/transformers/adaln.py | 328 ++-- .../ltx_video/transformers/attention.py | 1696 ++++++++--------- .../transformers/caption_projection.py | 60 +- .../ltx_video/transformers/transformer3d.py | 585 +++--- 11 files changed, 2023 insertions(+), 2122 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 677d64e4e..42e50d775 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -65,447 +65,440 @@ } try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_onnx_objects # noqa F403 + from .utils import dummy_onnx_objects # noqa F403 - _import_structure["utils.dummy_onnx_objects"] = [ - name for name in dir(dummy_onnx_objects) if not name.startswith("_")] + _import_structure["utils.dummy_onnx_objects"] = [name for name in dir(dummy_onnx_objects) if not name.startswith("_")] else: - _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() + if not is_torch_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_pt_objects # noqa F403 + from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [ - name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) - _import_structure["optimization"] = [ - "get_constant_schedule", - "get_constant_schedule_with_warmup", - "get_cosine_schedule_with_warmup", - "get_cosine_with_hard_restarts_schedule_with_warmup", - "get_linear_schedule_with_warmup", - "get_polynomial_decay_schedule_with_warmup", - "get_scheduler", - ] - - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - 
"DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) - _import_structure["training_utils"] = ["EMAModel"] + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ] + ) + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ] + ) + _import_structure["schedulers"].extend( + [ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_scipy_objects # noqa F403 + from .utils import dummy_torch_and_scipy_objects # noqa F403 - _import_structure["utils.dummy_torch_and_scipy_objects"] = [ - name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import 
dummy_torch_and_torchsde_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) try: - if not (is_torch_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 + from .utils import dummy_torch_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - 
"StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + 
"VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) try: - if not (is_torch_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - ["StableDiffusionKDiffusionPipeline"]) + _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"]) try: - if not (is_torch_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_librosa_objects # noqa F403 + from .utils import dummy_torch_and_librosa_objects # noqa F403 - _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not 
(is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() + if not is_flax_available(): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_objects # noqa F403 + from .utils import dummy_flax_objects # noqa F403 - _import_structure["utils.dummy_flax_objects"] = [ - name for name in dir(dummy_flax_objects) if not name.startswith("_")] + _import_structure["utils.dummy_flax_objects"] = [name for name in dir(dummy_flax_objects) if not name.startswith("_")] else: - _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = [ - "FlaxUNet2DConditionModel"] - _import_structure["models.flux.transformers.transformer_flux_flax"] = [ - "FluxTransformer2DModel"] - _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] - _import_structure["models.ltx_video.transformers.transformer3d"] = [ - "Transformer3DModel"] - _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_flax_and_transformers_objects # noqa F403 + from .utils import dummy_flax_and_transformers_objects # noqa F403 - _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not 
name.startswith("_") - ] + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_note_seq_objects # noqa F403 + from .utils import dummy_note_seq_objects # noqa F403 - _import_structure["utils.dummy_note_seq_objects"] = [ - name for name in dir(dummy_note_seq_objects) if not name.startswith("_") - ] + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] else: - _import_structure["pipelines"].extend(["MidiProcessor"]) + _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .configuration_utils import ConfigMixin - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 - else: - from .pipelines import OnnxRuntimeModel - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 - else: - import generate - import max_utils - import pyconfig - import input_pipeline - import transformers - from .models.controlnet_flax import FlaxControlNetModel - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .models.ltx_video.transformers.transformer3d import Transformer3DModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxEulerDiscreteScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - try: - if not (is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipelines import ( - FlaxStableDiffusionControlNetPipeline, - FlaxStableDiffusionXLControlNetPipeline, - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - FlaxStableDiffusionXLPipeline, - ) - - try: - if not (is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_note_seq_objects import * # noqa F403 - else: - from .pipelines import MidiProcessor + from .configuration_utils import ConfigMixin -else: - import sys - - 
sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 + else: + import generate + import max_utils + import pyconfig + import input_pipeline + import transformers + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, ) + + try: + if not (is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionXLControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6d96aa8c2..d05203f5c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -14,67 +14,55 @@ limitations under the License. 
""" - from absl import app from typing import Sequence import jax import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os -import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, - setup_initial_state, ) -from jax.sharding import Mesh, PartitionSpec as P def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): - print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) - print("fractional_coords.shape: ", - fractional_coords.shape, fractional_coords.dtype) - print("latents.shape: ", latents.shape, latents.dtype) - print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 + base_dir = os.path.dirname(__file__) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + # load in model config + config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 - base_dir = os.path.dirname(__file__) + transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) - # load in model config - config_path = os.path.join( - base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") - with open(config_path, "r") as f: - model_config = json.load(f) + key, split_key = jax.random.split(key) - transformer = Transformer3DModel( - **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights( - key, batch_size, text_tokens, num_tokens, features, eval_only=False) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( - transformer.init_weights, - split_key, - batch_size, - text_tokens, - num_tokens, - features, - eval_only=True - ) + weights_init_fn = functools.partial( + transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + ) def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) - run(pyconfig.config) + pyconfig.initialize(argv) + run(pyconfig.config) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 20c27ab20..96a6f1286 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -25,15 +25,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL - from .lora import * - from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from .ltx_video.transformers.transformer3d import 
Transformer3DModel + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + from .lora import * + from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .ltx_video.transformers.transformer3d import Transformer3DModel else: - import sys + import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index f32cc9459..ef8c530ba 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -8,63 +8,63 @@ class GradientCheckpointType(Enum): - """ - Defines the type of the gradient checkpoint we will have + """ + Defines the type of the gradient checkpoint we will have - NONE - means no gradient checkpoint - FULL - means full gradient checkpoint, wherever possible (minimum memory usage) - MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, - except for ones that involve batch dimension - that means that all attention and projection - layers will have gradient checkpoint, but not the backward with respect to the parameters - """ + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ - NONE = auto() - FULL = auto() - MATMUL_WITHOUT_BATCH = auto() + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() - @classmethod - def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": - """ - Constructs the gradient checkpoint type from a string + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string - Args: - s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. 
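# A minimal usage sketch for the checkpointing policies described above, using the
# from_str / apply helpers defined later in this class. ToyBlock is a made-up
# placeholder module, and the policy string mirrors the "matmul_without_batch"
# value passed to Transformer3DModel in generate_ltx_video.py.
import flax.linen as nn
from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType


class ToyBlock(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(16)(x)


# MATMUL_WITHOUT_BATCH wraps the module in nn.remat with the
# checkpoint_dots_with_no_batch_dims policy; "none" (or None) returns it unchanged.
RematToyBlock = GradientCheckpointType.from_str("matmul_without_batch").apply(ToyBlock)
PlainToyBlock = GradientCheckpointType.from_str(None).apply(ToyBlock)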
- Returns: - GradientCheckpointType: The policy that corresponds to the string - """ - if s is None: - s = "none" - return GradientCheckpointType[s.upper()] + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] - def to_jax_policy(self): - """ - Converts the gradient checkpoint type to a jax policy - """ - match self: - case GradientCheckpointType.NONE: - return SKIP_GRADIENT_CHECKPOINT_KEY - case GradientCheckpointType.FULL: - return None - case GradientCheckpointType.MATMUL_WITHOUT_BATCH: - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nn.Module) -> nn.Module: - """ - Applies a gradient checkpoint policy to a module - if no policy is needed, it will return the module as is + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is - Args: - module (nn.Module): the module to apply the policy to + Args: + module (nn.Module): the module to apply the policy to - Returns: - nn.Module: the module with the policy applied - """ - policy = self.to_jax_policy() - if policy == SKIP_GRADIENT_CHECKPOINT_KEY: - return module - return nn.remat( # pylint: disable=invalid-name - module, - prevent_cse=False, - policy=policy, - ) + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index fd92c695d..31b21cdd9 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -13,99 +13,97 @@ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) -NdInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] -KernelInitializer = Callable[[jax.random.PRNGKey, Shape, - jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] class DenseGeneral(nn.Module): - """A linear transformation with flexible axes. - - Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 - - Attributes: - features: tuple with numbers of output features. 
- axis: tuple with axes to apply the transformation on. - weight_dtype: the dtype of the weights (default: float32). - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - use_bias: whether to add bias in linear transformation. - bias_norm: whether to add normalization before adding bias. - quant: quantization config, defaults to None implying no quantization. + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. """ - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - weight_dtype: jnp.dtype = jnp.float32 - dtype: np.dtype = jnp.float32 - kernel_init: KernelInitializer = lecun_normal() - kernel_axes: Tuple[Optional[str], ...] = () - use_bias: bool = False - matmul_precision: str = "default" - - bias_init: Initializer = jax.nn.initializers.constant(0.0) - - @nn.compact - def __call__(self, inputs: jax.Array) -> jax.Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. 
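# A short usage sketch for DenseGeneral's flexible output axes (illustrative sizes and
# logical axis names, not taken from the LTX-Video model config): a tuple of features
# expresses a fused heads x head_dim projection with a single kernel.
import jax
import jax.numpy as jnp
from maxdiffusion.models.ltx_video.linear import DenseGeneral

layer = DenseGeneral(features=(8, 64), axis=-1, kernel_axes=("embed", "heads", "kv"))
x = jnp.ones((2, 16, 512))  # (batch, tokens, embed)
params = layer.init(jax.random.PRNGKey(0), x)  # kernel shape: (512, 8, 64)
y = layer.apply(params, x)
assert y.shape == (2, 16, 8, 64)  # the embed axis is contracted into (heads, head_dim)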
- """ - - def compute_dot_general(inputs, kernel, axis, contract_ind): - """Computes a dot_general operation that may be quantized.""" - dot_general = jax.lax.dot_general - matmul_precision = jax.lax.Precision(self.matmul_precision) - return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) - - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - kernel_in_axis = np.arange(len(axis)) - kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) - kernel = self.param( - "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - ) - kernel = jnp.asarray(kernel, self.dtype) - - contract_ind = tuple(range(0, len(axis))) - output = compute_dot_general(inputs, kernel, axis, contract_ind) - - if self.use_bias: - bias_axes, bias_shape = ( - self.kernel_axes[-len(features):], - kernel_shape[-len(features):], - ) - bias = self.param( - "bias", - nn.with_logical_partitioning(self.bias_init, bias_axes), - bias_shape, - self.weight_dtype, - ) - bias = jnp.asarray(bias, self.dtype) - - output += bias - return output + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + # kernel_in_axis = np.arange(len(axis)) + # kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index 882f21ace..aaed41048 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -7,99 +7,96 @@ class RepeatableCarryBlock(nn.Module): - """ - Integrates an input module in a jax carry format + """ + Integrates an input module in a jax carry format - ergo, the module assumes the role of a building block - and returns both input and output across all blocks - """ + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ - module: Callable[[Any], nn.Module] - module_init_args: List[Any] - module_init_kwargs: Dict[str, Any] + module: Callable[[Any], nn.Module] + module_init_args: List[Any] + module_init_kwargs: Dict[str, Any] - @nn.compact - def __call__(self, 
*args) -> Tuple[jax.Array, None]: - """ - jax carry-op format of block - assumes the input contains an input tensor to the block along with kwargs that might be send to the block - kwargs are assumed to have static role, while the input changes between cycles + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles - Returns: - Tuple[jax.Array, None]: Output tensor from the block - """ - mod = self.module(*self.module_init_args, **self.module_init_kwargs) - output = mod(*args) - return output, None + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None class RepeatableLayer(nn.Module): - """ - RepeatableLayer will assume a similar role to torch.nn.ModuleList - with the condition that each block has the same graph, and only the parameters differ + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ - The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation - """ + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ - module: Callable[[Any], nn.Module] - """ + module: Callable[[Any], nn.Module] + """ A Callable function for single block construction """ - num_layers: int - """ + num_layers: int + """ The amount of blocks to build """ - module_init_args: List[Any] = field(default_factory=list) - """ + module_init_args: List[Any] = field(default_factory=list) + """ args passed to RepeatableLayer.module callable, to support block construction """ - module_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """ + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ kwargs passed to RepeatableLayer.module callable, to support block construction """ - pspec_name: Optional[str] = None - """ + pspec_name: Optional[str] = None + """ Partition spec metadata """ - param_scan_axis: int = 0 - """ + param_scan_axis: int = 0 + """ The axis that the "layers" will be aggragated on eg: if a kernel is shaped (8, 16) N layers will be (N, 8, 16) if param_scan_axis=0 and (8, N, 16) if param_scan_axis=1 """ - @nn.compact - def __call__(self, *args): - - scan_kwargs = {} - if self.pspec_name is not None: - scan_kwargs["metadata_params"] = { - nn.PARTITION_NAME: self.pspec_name} - - initializing = self.is_mutable_collection("params") - params_spec = self.param_scan_axis if initializing else partitioning.ScanIn( - self.param_scan_axis) - scan_fn = nn.scan( - RepeatableCarryBlock, - variable_axes={ - "params": params_spec, - "cache": 0, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, # Separate params per timestep - split_rngs={"params": True}, - in_axes=(nn.broadcast,) * (len(args) - 1), - length=self.num_layers, - **scan_kwargs, - ) - wrapped_function = scan_fn( - self.module, self.module_init_args, self.module_init_kwargs) - x, _ = wrapped_function(*args) - return x + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = 
self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 3e1fd6d6e..4a78b48ea 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -22,155 +22,154 @@ @jax.jit def approximate_gelu(x: jax.Array) -> jax.Array: - """ - Computes Gaussian Error Linear Unit (GELU) activation function + """ + Computes Gaussian Error Linear Unit (GELU) activation function - Args: - x (jax.Array): The input tensor + Args: + x (jax.Array): The input tensor - jax.Array: The output tensor - """ - # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs - # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior - if x.dtype in (jax.numpy.float64,): - x = x.clip(-10, None) - return jax.nn.gelu(x, approximate=True) + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) def get_activation(act_fn: str): - """Returns the activation function from string.""" - act_fn = act_fn.lower() - if act_fn in ACTIVATION_FUNCTIONS: - return ACTIVATION_FUNCTIONS[act_fn] - raise ValueError(f"Unsupported activation function: {act_fn}") + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - approximate: str = "none" - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] 
= () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def gelu(self, gate: jax.Array) -> jax.Array: - approximate_to_tanh = self.approximate == "tanh" - if approximate_to_tanh: - return approximate_gelu(gate) - else: - return jax.nn.gelu(gate, approximate=False) - - @nn.compact - def __call__(self, hidden_states): - if self.approximate not in ("none", "tanh"): - raise ValueError( - f"approximate must be 'none' or 'tanh', got {self.approximate}") - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - hidden_states = proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states class GEGLU(nn.Module): - r""" - A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
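# Shape sketch of the gating performed at the end of __call__ below (illustrative
# sizes): one projection produces 2 * dim_out features, which are split into a value
# half and a gate half, and the output is value * gelu(gate).
import jax
import jax.numpy as jnp

dim_out = 256
projected = jnp.ones((2, 8, 2 * dim_out))  # output of the `proj` DenseGeneral
value, gate = jnp.split(projected, 2, axis=-1)  # two (2, 8, 256) halves
out = value * jax.nn.gelu(gate, approximate=False)  # gated hidden states, (2, 8, 256)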
- deprecate("scale", "1.0.0", deprecation_message) - - proj = DenseGeneral( - features=self.dim_out * 2, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - - hidden_states = proj(hidden_states) - hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) - return hidden_states * jax.nn.gelu(gate, approximate=False) + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) class ApproximateGELU(nn.Module): - r""" - The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this - [paper](https://arxiv.org/abs/1606.08415). - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - dim_in: int - dim_out: int - bias: bool = True - - kernel_axes: Tuple[Optional[str], ...] = () - kernel_init: KernelInitializer = lecun_normal() - - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, x): - proj = DenseGeneral( - features=self.dim_out, - use_bias=self.bias, - kernel_axes=self.kernel_axes, - kernel_init=self.kernel_init, - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj", - ) - x = proj(x) - return x * jax.nn.sigmoid(1.702 * x) + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] 
= () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 374af6acc..4bc27e8bc 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -17,185 +17,177 @@ def get_timestep_embedding_multidim( scale: float = 1, max_period: int = 10000, ) -> jnp.ndarray: - """ - Computes sinusoidal timestep embeddings while preserving the original dimensions. - No reshaping to 1D is performed at any stage. - - Args: - timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. - embedding_dim (int): The dimension of the output. - flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) - or `sin, cos` (if False). - downscale_freq_shift (float): Controls the delta between frequencies between dimensions. - scale (float): Scaling factor applied to the embeddings. - max_period (int): Controls the maximum frequency of the embeddings. - - Returns: - jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. - """ - half_dim = embedding_dim // 2 - exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) - exponent = exponent / (half_dim - downscale_freq_shift) - shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) - emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape - emb = nn.with_logical_constraint( - emb, ("activation_batch", "activation_norm_length", "activation_embed")) - # Broadcasting to match shape (*timesteps.shape, half_dim) - emb = timesteps[..., None] * emb - emb = scale * emb - # Shape (*timesteps.shape, embedding_dim) - emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) - if flip_sin_to_cos: - emb = jnp.concatenate( - [emb[..., half_dim:], emb[..., :half_dim]], axis=-1) - - return emb + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = timesteps[..., None] * emb + emb = scale * emb + # Shape (*timesteps.shape, embedding_dim) + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb class TimestepEmbedding(nn.Module): - in_channels: int - time_embed_dim: int - act_fn: str = "silu" - out_dim: Optional[int] = None - sample_proj_bias: bool = True - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize layers efficiently""" - self.linear_1 = DenseGeneral( - self.time_embed_dim, - use_bias=self.sample_proj_bias, - kernel_axes=(None, "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - ) - - self.act = get_activation(self.act_fn) - time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim - self.linear_2 = DenseGeneral( - time_embed_dim_out, - use_bias=self.sample_proj_bias, - kernel_axes=("embed", "mlp"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - ) - - def __call__(self, sample, condition=None): - sample = nn.with_logical_constraint( - sample, ("activation_batch", "activation_norm_length", "activation_embed")) - sample = self.linear_1(sample) - sample = self.act(sample) - sample = self.linear_2(sample) - return sample + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample class Timesteps(nn.Module): - num_channels: int - flip_sin_to_cos: bool - downscale_freq_shift: float - scale: int = 1 - - def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: - t_emb = get_timestep_embedding_multidim( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - return t_emb + 
num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - """ - - """ - - embedding_dim: int - size_emb_dim: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - """Initialize sub-modules.""" - self.outdim = self.size_emb_dim - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, - time_embed_dim=self.embedding_dim, - name="timestep_embedder", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.astype(hidden_dtype)) - return timesteps_emb + """ """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb class AdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ + Compute AdaLayerNorm-Single modulation. 
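# Shape sketch for the adaLN-single path defined in this class (illustrative sizes;
# embedding_dim=1152 and the empty resolution/aspect_ratio kwargs are assumptions for
# the example, not values taken from this patch): a (batch, tokens) timestep tensor is
# lifted to a 256-dim sinusoidal embedding, projected to embedding_dim, then expanded
# to 6 * embedding_dim so the transformer block can split it into shift/scale/gate terms.
import jax
import jax.numpy as jnp
from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle

adaln = AdaLayerNormSingle(embedding_dim=1152)
timestep = jnp.ones((2, 1))  # (batch, tokens)
cond = {"resolution": None, "aspect_ratio": None}  # unused by the Alpha embedder
params = adaln.init(jax.random.PRNGKey(0), timestep, added_cond_kwargs=cond, batch_size=2, hidden_dtype=jnp.float32)
modulation, embedded = adaln.apply(params, timestep, added_cond_kwargs=cond, batch_size=2, hidden_dtype=jnp.float32)
# modulation: (2, 1, 6 * 1152), later split into shift/scale/gate for attention and MLP
# embedded:   (2, 1, 1152), the raw embedded timestep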
- embedding_dim: int - embedding_coefficient: int = 6 - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - self.emb = AlphaCombinedTimestepSizeEmbeddings( - self.embedding_dim, - size_emb_dim=self.embedding_dim // 3, - name="emb", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - self.silu = jax.nn.silu - self.linear = DenseGeneral( - self.embedding_coefficient * self.embedding_dim, - use_bias=True, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear", - ) - - def __call__( - self, - timestep: jnp.ndarray, - added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, - batch_size: Optional[int] = None, - hidden_dtype: Optional[jnp.dtype] = None, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Compute AdaLayerNorm-Single modulation. - - Returns: - Tuple: - - Processed embedding after SiLU + linear transformation. - - Original embedded timestep. - """ - embedded_timestep = self.emb( - timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) - return self.linear(self.silu(embedded_timestep)), embedded_timestep + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4ade671c7..5d12e7813 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -25,921 +25,869 @@ class SkipLayerStrategy(Enum): - AttentionSkip = auto() - AttentionValues = auto() - Residual = auto() - TransformerBlock = auto() + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() class Identity(nn.Module): - def __call__(self, x): - return x + + def __call__(self, x): + return x class BasicTransformerBlock(nn.Module): - dim: int - num_attention_heads: int - attention_head_dim: int - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - attention_bias: bool = False - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - norm_elementwise_affine: bool = True - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" - norm_eps: float = 1e-5 - qk_norm: str = None - final_dropout: bool = False - attention_type: str = ("default",) # pylint: disable=unused-argument - ff_inner_dim: Optional[int] = None - ff_bias: bool = True - attention_out_bias: bool = True - use_tpu_flash_attention: bool = True - use_rope: bool = False - ffn_dim_mult: Optional[int] = 4 - attention_op: Optional[nn.Module] = None - sharding_mesh: Optional[jax.sharding.Mesh] = None - - dtype: jax.numpy.dtype = jnp.float32 - weight_dtype: jax.numpy.dtype = jnp.float32 - matmul_precision: str = "default" - - def setup(self): - assert self.standardization_norm in ["layer_norm", "rms_norm"] - assert self.adaptive_norm in [ - "single_scale_shift", "single_scale", "none"] - assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." 
- - if self.standardization_norm == "layer_norm": - make_norm_layer = partial( - nn.LayerNorm, - epsilon=self.norm_eps, - param_dtype=self.weight_dtype, - dtype=self.dtype, - ) - else: - make_norm_layer = partial( - RMSNorm, - epsilon=self.norm_eps, - elementwise_affine=self.norm_elementwise_affine, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("norm",), - ) - - # 1. Self-Attn - self.norm1 = make_norm_layer(name="norm1") - self.attn1 = Attention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn1", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in ["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. 
Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. 
Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + # [batch, 1 or num_tokens, embedding_dim] + assert timestep.ndim == 3 + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = (jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2)) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint(attn_output, ("activation_batch", "activation_norm_length", "activation_embed")) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint(attn_input, ("activation_batch", "activation_norm_length", "activation_embed")) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint(ff_output, ("activation_batch", "activation_norm_length", "activation_embed")) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states - # 2. Cross-Attn - if self.cross_attention_dim is not None or self.double_self_attention: - self.attn2 = Attention( - query_dim=self.dim, - cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dropout=self.dropout, - bias=self.attention_bias, - upcast_attention=self.upcast_attention, - out_bias=self.attention_out_bias, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - attention_op=self.attention_op, - name="attn2", - dtype=self.dtype, - weight_dtype=self.weight_dtype, - ) - if self.adaptive_norm == "none": - self.attn2_norm = make_norm_layer() - else: - self.attn2 = None - self.attn2_norm = None - - self.norm2 = make_norm_layer(name="norm2") - # 3. 
Feed-forward - self.ff = FeedForward( - self.dim, - dropout=self.dropout, - activation_fn=self.activation_fn, - final_dropout=self.final_dropout, - inner_dim=self.ff_inner_dim, - bias=self.ff_bias, - mult=self.ffn_dim_mult, - name="ff", + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), dtype=self.dtype, weight_dtype=self.weight_dtype, + name="to_out.0", matmul_precision=self.matmul_precision, - ) - - # 4. Scale-Shift - if self.adaptive_norm != "none": - num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 - - def ada_initalizer(key): - return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), - ) - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - segment_ids: Optional[jnp.ndarray] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - encoder_attention_segment_ids: Optional[jnp.ndarray] = None, - timestep: Optional[jnp.ndarray] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[jnp.ndarray] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[SkipLayerStrategy] = None, - ) -> jnp.ndarray: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - print( - "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") - - hidden_states = nn.with_logical_constraint( - hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - hidden_states = checkpoint_name( - hidden_states, "basic_transformer_block hidden_states") - - batch_size = hidden_states.shape[0] - - # 0. 
Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - - # Adaptive Norm - if self.adaptive_norm in ["single_scale_shift", "single_scale"]: - # [batch, 1 or num_tokens, embedding_dim] - assert timestep.ndim == 3 - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( - batch_size, timestep.shape[1], num_ada_params, -1 - ) - # Moving ada values to computation dtype to prevent dtype promotion - ada_values = ada_values.astype(self.dtype) - ada_values = nn.with_logical_constraint( - ada_values, ("activation_batch", "activation_norm_length", - "activation_ada", "activation_embed") - ) - - if self.adaptive_norm == "single_scale_shift": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) - ) - norm_hidden_states = norm_hidden_states * \ - (1 + scale_msa) + shift_msa - else: - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) - ) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) - elif self.adaptive_norm == "none": - scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - if norm_hidden_states.shape[1] == 1: - norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) - - # 1. Self-Attention - attn_output = self.attn1( - norm_hidden_states, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, - sharding_mesh=self.sharding_mesh, - skip_layer_mask=skip_layer_mask, - skip_layer_strategy=skip_layer_strategy, - **(cross_attention_kwargs or {}), - ) + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, 
channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention(query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if 
self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states - attn_output = nn.with_logical_constraint( - attn_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_msa is not None: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - - # 3. Cross-Attention - if self.attn2 is not None: - attn_input = self.attn2_norm( - hidden_states) if self.adaptive_norm == "none" else hidden_states - attn_input = nn.with_logical_constraint( - attn_input, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - attn_output = self.attn2( - attn_input, - freqs_cis=freqs_cis, - encoder_hidden_states=encoder_hidden_states, - segment_ids=segment_ids, - kv_attention_segment_ids=encoder_attention_segment_ids, - sharding_mesh=self.sharding_mesh, - **(cross_attention_kwargs or {}), - ) - hidden_states = attn_output + hidden_states - - # 4. 
Feed-Forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = nn.with_logical_constraint( - norm_hidden_states, ("activation_batch", - "activation_norm_length", "activation_embed") - ) +class AttentionOp(nn.Module): - if self.adaptive_norm == "single_scale_shift": - norm_hidden_states = norm_hidden_states * \ - (1 + scale_mlp) + shift_mlp - elif self.adaptive_norm == "single_scale": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) - elif self.adaptive_norm == "none": - pass - else: - raise ValueError( - f"Unknown adaptive norm type: {self.adaptive_norm}") - - ff_output = self.ff(norm_hidden_states) - ff_output = nn.with_logical_constraint( - ff_output, ("activation_batch", - "activation_norm_length", "activation_embed") - ) - if gate_mlp is not None: - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = jnp.squeeze(hidden_states, axis=1) - hidden_states = nn.with_logical_constraint( - hidden_states, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - return hidden_states + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. + qkvo_sharding_spec = jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert"), + ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + None, + ) + # Based on: ("activation_kv_batch", "activation_length") + qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. 
+
+    TPU kernel ops run in grids; the bigger the grid, the more data is loaded into SRAM,
+    and we want to utilize the SRAM as well as we can.

-class Attention(nn.Module):
-  query_dim: int
-  cross_attention_dim: Optional[int] = None
-  heads: int = 8
-  dim_head: int = 64
-  dropout: float = 0.0
-  bias: bool = False
-  upcast_attention: bool = False
-  upcast_softmax: bool = False
-  cross_attention_norm: Optional[str] = None
-  added_kv_proj_dim: Optional[int] = None
-  out_bias: bool = True
-  scale_qk: bool = True
-  qk_norm: Optional[str] = None
-  only_cross_attention: bool = False
-  eps: float = 1e-5
-  rescale_output_factor: float = 1.0
-  residual_connection: bool = False
-  out_dim: Optional[int] = None
-  use_tpu_flash_attention: bool = True
-  use_rope: bool = False
-  attention_op: Optional[nn.Module] = None
-
-  dtype: jnp.dtype = jnp.float32
-  weight_dtype: jnp.dtype = jnp.float32
-  matmul_precision: str = "default"
-
-  def setup(self):
-    """Initialize layers in Flax `setup()`."""
-    self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads
-    self.use_bias = self.bias
-    self.is_cross_attention = self.cross_attention_dim is not None
-    self.fused_projections = False
-    out_dim = self.out_dim if self.out_dim is not None else self.query_dim
-    self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0
-
-    # Query and Key Normalization
-    if self.qk_norm is None:
-      self.q_norm = Identity()
-      self.k_norm = Identity()
-    elif self.qk_norm == "rms_norm":
-      self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",))
-      self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",))
-    elif self.qk_norm == "layer_norm":
-      self.q_norm = nn.LayerNorm(epsilon=self.eps)
-      self.k_norm = nn.LayerNorm(epsilon=self.eps)
-    else:
-      raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}")
-
-    if out_dim is not None:
-      self.heads_count = out_dim // self.dim_head
-
-    # Validate parameters
-    if self.added_kv_proj_dim is None and self.only_cross_attention:
-      raise ValueError(
-        "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. "
-        "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
-      )
-
-    if self.cross_attention_norm is None:
-      self.norm_cross = None
-    elif self.cross_attention_norm == "layer_norm":
-      self.norm_cross = nn.LayerNorm(epsilon=self.eps)
-    else:
-      raise ValueError(
-        f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'."
- ) - - # Linear layers for queries, keys, values - self.to_q = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_q", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv"), - axis=-1, - ) - - if not self.only_cross_attention: - self.to_k = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_k", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - self.to_v = DenseGeneral( - features=(self.inner_dim,), - use_bias=self.bias, - name="to_v", - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - kernel_axes=("embed", "kv_head_dim"), - axis=-1, - ) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") - self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") - - self.to_out = [ - DenseGeneral( - features=(out_dim,), - use_bias=self.out_bias, - axis=-1, - kernel_axes=("kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="to_out.0", - matmul_precision=self.matmul_precision, - ), - nn.Dropout(self.dropout), - ] - - if self.attention_op is not None: - self.attention = self.attention_op - else: - _tpu_available = any( - device.platform == "tpu" for device in jax.devices()) - self.attention = AttentionOp() if _tpu_available else ExplicitAttention() - if not _tpu_available: - print( - "Warning: Running with explicit attention since tpu is not available.") - - def __call__( - self, - hidden_states: jnp.ndarray, - freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, - encoder_hidden_states: Optional[jnp.ndarray] = None, - segment_ids: Optional[jnp.ndarray] = None, - kv_attention_segment_ids: Optional[jnp.ndarray] = None, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - skip_layer_mask: Optional[jnp.ndarray] = None, - skip_layer_strategy: Optional[str] = None, - temb: Optional[jnp.ndarray] = None, - deterministic: bool = True, - **cross_attention_kwargs, - ) -> jnp.ndarray: - cross_attention_kwargs = { - k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - assert cross_attention_kwargs.get( - "scale", None) is None, "Not supported" - - input_axis_names = ("activation_batch", - "activation_length", "activation_embed") - hidden_states = nn.with_logical_constraint( - hidden_states, input_axis_names) - if encoder_hidden_states is not None: - encoder_hidden_states = nn.with_logical_constraint( - encoder_hidden_states, input_axis_names) - - residual = hidden_states - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = jnp.reshape( - hidden_states, (batch_size, channel, height * width)) - hidden_states = jnp.swapaxes(hidden_states, 1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - if skip_layer_mask is not None: - skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) - - query = self.to_q(hidden_states) - query = self.q_norm(query) - - if encoder_hidden_states is not None: - if self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states( - encoder_hidden_states) - key = self.to_k(encoder_hidden_states) - key = self.k_norm(key) - else: - encoder_hidden_states = hidden_states - key = self.to_k(hidden_states) - key 
= self.k_norm(key)
-    if self.use_rope:
-      key = apply_rotary_emb(key, freqs_cis)
-      query = apply_rotary_emb(query, freqs_cis)
-
-    value = self.to_v(encoder_hidden_states)
-
-    inner_dim = key.shape[-1]
-    head_dim = inner_dim // self.heads
-
-    query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim))
-    query = jnp.swapaxes(query, 1, 2)
-    query = nn.with_logical_constraint(
-      query, ("activation_kv_batch", "activation_kv_heads",
-          "activation_length", "activation_kv_head_dim")
-    )
-    query = checkpoint_name(query, "attention query")
+    Grids that are too big will cause cache misses and slow down the computation while the
+    faster SRAM fetches the next block's data from the slower HBM.

-    key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim))
-    key = jnp.swapaxes(key, 1, 2)
-    key = nn.with_logical_constraint(
-      key, ("activation_kv_batch", "activation_kv_heads",
-          "activation_length", "activation_kv_head_dim")
-    )
-    key = checkpoint_name(key, "attention key")
+    A balance has to be struck to get the best performance. That balance should be computed
+    from the information supplied by q and k (which give the query and key/value sequence
+    lengths) together with the SRAM cache size.

-    value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim))
-    value = jnp.swapaxes(value, 1, 2)
-    value = nn.with_logical_constraint(
-      value, ("activation_kv_batch", "activation_kv_heads",
-          "activation_length", "activation_kv_head_dim")
-    )
-    value = checkpoint_name(value, "attention value")
+    SRAM cache size for TPU: v5p has 1MB of SRAM per core.

-    assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`"
+    Args:
+      q (jax.Array): Query tensor to be used
+      k (jax.Array): Key tensor to be used

-    q_segment_ids = segment_ids
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.astype(jnp.float32)
+    Returns:
+      BlockSizes: Grid block sizes
+    """
+    max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+    return BlockSizes(
+      block_q=min(max_block_size, q.shape[-2]),
+      block_k_major=min(max_block_size, k.shape[-2]),
+      block_k=min(max_block_size, k.shape[-2]),
+      block_b=min(1, q.shape[0]),
+      block_q_major_dkv=min(max_block_size, q.shape[-2]),
+      block_k_major_dkv=min(max_block_size, k.shape[-2]),
+      block_q_dkv=min(max_block_size, q.shape[-2]),
+      block_k_dkv=min(max_block_size, k.shape[-2]),
+      block_q_dq=min(max_block_size, q.shape[-2]),
+      block_k_dq=min(512, k.shape[-2]),
+      block_k_major_dq=min(max_block_size, k.shape[-2]),
+    )

-    if kv_attention_segment_ids is not None and q_segment_ids is None:
-      q_segment_ids = jnp.ones(
-        (batch_size, query.shape[2]), dtype=jnp.float32)

-    hidden_states_a = self.attention(
-      query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype
-    )
+class ExplicitAttention(nn.Module):

-    hidden_states_a: jax.Array = nn.with_logical_constraint(
-      hidden_states_a, ("activation_kv_batch", "activation_heads",
-          "activation_length", "activation_kv")
-    )
+  def __call__(
+    self,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    q_segment_ids: jax.Array,
+    kv_segment_ids: jax.Array,
+    sharding_mesh: Optional[jax.sharding.Mesh] = None,
+    dtype: jnp.dtype = jnp.float32,
+  ):
+    assert sharding_mesh is None, "Explicit attention does not support sharding mesh."
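+    # Reference (non-flash) attention path, used when no TPU is available. Segment ids mark
+    # which tokens belong to the same sequence; query/key positions whose segment ids differ
+    # are masked out below by adding -inf to their attention logits before the softmax,
+    # mirroring what SegmentIds does for the TPU flash-attention kernel.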
+ attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v - hidden_states_a = jnp.reshape(jnp.swapaxes( - hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: - hidden_states = hidden_states_a * skip_layer_mask + \ - hidden_states * (1.0 - skip_layer_mask) - else: - hidden_states = hidden_states_a - - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1]( - hidden_states, deterministic=deterministic) # Dropout - - if input_ndim == 4: - hidden_states = jnp.reshape(jnp.swapaxes( - hidden_states, -1, -2), (batch_size, channel, height, width)) - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - skip_layer_mask = jnp.reshape( - skip_layer_mask, (batch_size, 1, 1, 1)) - - if self.residual_connection: - if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: - hidden_states = hidden_states + residual * skip_layer_mask - else: - hidden_states = hidden_states + residual - - if self.rescale_output_factor != 1.0: - hidden_states = hidden_states / self.rescale_output_factor - hidden_states = checkpoint_name(hidden_states, "attention_output") - - return hidden_states - - def prepare_attention_mask( - self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 - ) -> jnp.ndarray: - head_size = self.heads_count - if attention_mask is None: - return attention_mask - - current_length = attention_mask.shape[-1] - if current_length != target_length: - remaining_length = target_length - current_length - attention_mask = jnp.pad( - attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = jnp.repeat(attention_mask, head_size, axis=0) - elif out_dim == 4: - attention_mask = jnp.expand_dims(attention_mask, axis=1) - attention_mask = jnp.repeat(attention_mask, head_size, axis=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: - assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) - else: - raise ValueError("Unknown normalization type for cross-attention.") - - return encoder_hidden_states +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. 
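+
+  Concretely, the layer computes x * rsqrt(mean(x**2, axis=-1) + epsilon), with the variance
+  accumulated in float32, and, when `elementwise_affine` is True, multiplies the result by a
+  learned per-feature scale.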
+ """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. -class AttentionOp(nn.Module): - @nn.compact - def __call__( - self, - q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] - k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] - q_segment_ids: jax.Array, # [batch_size, q_tokens] - kv_segment_ids: jax.Array, # [batch_size, kv_tokens] - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - block_sizes: Optional[BlockSizes] = None, - ): - if block_sizes is None: - block_sizes = self.default_block_sizes(q, k, dtype) - - scale_factor = 1 / math.sqrt(q.shape[-1]) - - def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): - s = ( - # flash attention expects segment ids to be float32 - SegmentIds(q_segment_ids.astype(jnp.float32), - kv_segment_ids.astype(jnp.float32)) - if q_segment_ids is not None and kv_segment_ids is not None - else None - ) - output = jax_flash_attention( - q, - k, - v, - None, - s, - sm_scale=scale_factor, - block_sizes=block_sizes, - ) - return output - - if sharding_mesh is not None: - if q.ndim != 4: - raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") - if q_segment_ids is not None and q_segment_ids.ndim != 2: - raise ValueError( - f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") - # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") - # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. - qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), - None, - None, - ) - # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), "sequence") - wrapped_flash_attention = shard_map( - partial_flash_attention, - mesh=sharding_mesh, - in_specs=( - qkvo_sharding_spec, - qkvo_sharding_spec, - qkvo_sharding_spec, - qkv_segment_ids_spec, - qkv_segment_ids_spec, - ), - out_specs=qkvo_sharding_spec, - check_rep=False, - ) - else: - wrapped_flash_attention = partial_flash_attention - - return wrapped_flash_attention( - q, - k, - v, - q_segment_ids, - kv_segment_ids, - ) + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. - def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: - """ - Default block sizes for Flash Attention. 
- - TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM - we want to utilize the SRAM the best we can - - too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data - from the slower HBRAM - - a certain balance has to be met to get the best performance - imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) - along with the SRAM cache size - - ** SRAM cache size for TPU - V5P - 1MB SRAM per core - - Args: - q (jax.Array): Query tensor to be used - k (jax.Array): Key tensor to be used - - Returns: - BlockSizes: Grid block sizes - """ - max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - return BlockSizes( - block_q=min(max_block_size, q.shape[-2]), - block_k_major=min(max_block_size, k.shape[-2]), - block_k=min(max_block_size, k.shape[-2]), - block_b=min(1, q.shape[0]), - block_q_major_dkv=min(max_block_size, q.shape[-2]), - block_k_major_dkv=min(max_block_size, k.shape[-2]), - block_q_dkv=min(max_block_size, q.shape[-2]), - block_k_dkv=min(max_block_size, k.shape[-2]), - block_q_dq=min(max_block_size, q.shape[-2]), - block_k_dq=min(512, k.shape[-2]), - block_k_major_dq=min(max_block_size, k.shape[-2]), - ) + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + Returns: + jax.Array: Normed data + """ -class ExplicitAttention(nn.Module): - def __call__( - self, - q: jax.Array, - k: jax.Array, - v: jax.Array, - q_segment_ids: jax.Array, - kv_segment_ids: jax.Array, - sharding_mesh: Optional[jax.sharding.Mesh] = None, - dtype: jnp.dtype = jnp.float32, - ): - assert sharding_mesh is None, "Explicit attention does not support sharding mesh." - attn_mask = None - if kv_segment_ids is not None: - q_segment_ids_expanded = q_segment_ids[:, None, :, None] - kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] - attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded - - scale_factor = 1 / jnp.sqrt(q.shape[-1]) - attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == jnp.bool_: - attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = q @ k.swapaxes(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = jnn.softmax(attn_weight, axis=-1) - - return attn_weight @ v + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) -class RMSNorm(nn.Module): - """ - RMSNorm is a normalization layer that normalizes the input using the root mean square. 
- """ + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) - epsilon: float - dtype: jnp.dtype = jnp.float32 - elementwise_affine: bool = True - weight_dtype: jnp.dtype = jnp.float32 - kernel_axes: Tuple[Optional[str], ...] = () - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, hidden_states: jax.Array) -> jax.Array: - """ - Forward pass of the RMSNorm layer. - - First we compute the variance (mean of the square of the input) - and then normalize the input using the root mean square. - - NOTE: if weight is in mixed precision, the operand should be in the same precision. - Args: - hidden_states (jax.Array): Input data - - Returns: - jax.Array: Normed data - """ - - # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim - dim = hidden_states.shape[-1] - if self.elementwise_affine: - scale = self.param( - "scale", - nn.with_logical_partitioning( - self.scale_init, self.kernel_axes), - (dim,), - self.weight_dtype, - ) - else: - scale = None - - input_dtype = hidden_states.dtype - variance = jnp.mean(jnp.square(hidden_states.astype( - jnp.float32)), axis=-1, keepdims=True) - hidden_states: jax.Array = hidden_states * \ - jax.lax.rsqrt(variance + self.epsilon) - - if self.elementwise_affine: - # convert into half-precision if necessary - hidden_states = (hidden_states.astype(self.dtype) - * scale.astype(self.dtype)).astype(input_dtype) - else: - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states + return hidden_states class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
- """ - - dim_out: Optional[int] = None - mult: int = 4 - dropout: float = 0.0 - activation_fn: str = "gelu" - final_dropout: bool = False - bias: bool = True - inner_dim: Optional[int] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - - @nn.compact - def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: - dim = hidden_states.shape[-1] - if self.inner_dim is None: - inner_dim = dim * self.mult - if inner_dim < 256: - raise ValueError("inner_dim must be at least 256") - # round to nearest multiple of 256 - inner_dim = round(inner_dim / 256) * 256 - else: - inner_dim = self.inner_dim - - dim_out = self.dim_out if self.dim_out is not None else dim - - act_kwargs = { - "name": "net.0", - "bias": self.bias, - "kernel_axes": ("embed", "mlp"), - "matmul_precision": self.matmul_precision, - "weight_dtype": self.weight_dtype, - "dtype": self.dtype, - } - match self.activation_fn: - case "gelu": - act_fn = GELU(dim, inner_dim, **act_kwargs) - case "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) - case "geglu": - act_fn = GEGLU(dim, inner_dim, **act_kwargs) - case "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) - case _: - raise ValueError( - f"activation function {self.activation_fn} not supported") - - if isinstance(act_fn, GEGLU): - hidden_states = act_fn(hidden_states, scale) - else: - hidden_states = act_fn(hidden_states) - - hidden_states = checkpoint_name(hidden_states, "FFN - activation") - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - hidden_states = DenseGeneral( - dim_out, - use_bias=self.bias, - kernel_axes=("mlp", "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="net.2", - )(hidden_states) - hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") - if self.final_dropout: - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - hidden_states = nn.Dropout(self.dropout)( - hidden_states, deterministic=deterministic) - - return hidden_states + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
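+
+  When `inner_dim` is not given, the hidden width defaults to `dim * mult` rounded to the
+  nearest multiple of 256; a product below 256 raises a ValueError.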
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + # round to nearest multiple of 256 + inner_dim = round(inner_dim / 256) * 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: - """ - Integrates positional information into input tensors using RoPE. + """ + Integrates positional information into input tensors using RoPE. 
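+
+  Concretely, the last axis is viewed as consecutive (x1, x2) pairs and the output is
+  input * cos + rotate_half(input) * sin, where rotate_half maps each pair to (-x2, x1).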
- Args: - input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) - freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies - Returns: - jax.Array: Tensor where positional information has been integrated into the original input tensor - """ - if len(freqs_cis) != 2: - raise ValueError("freqs_cis must be a tuple of 2 elements") + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") - cos_freqs, sin_freqs = freqs_cis + cos_freqs, sin_freqs = freqs_cis - t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) - t1, t2 = jnp.split(t_dup, 2, axis=-1) - t_dup = jnp.concatenate([-t2, t1], axis=-1) - input_tensor_rot = t_dup.reshape(*input_tensor.shape) + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) - # Apply rotary embeddings - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index dff8b8c62..f2b1af101 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -6,35 +6,35 @@ class CaptionProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. - """ + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
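+
+  Implemented as a two-layer MLP: a DenseGeneral projection to `hidden_size`, an approximate
+  GELU, and a second DenseGeneral projection to `hidden_size`.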
+ """ - in_features: int - hidden_size: int - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" - @nn.compact - def __call__(self, caption): - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_1", - )(caption) - hidden_states = approximate_gelu(hidden_states) - hidden_states = DenseGeneral( - self.hidden_size, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="linear_2", - )(hidden_states) - return hidden_states + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 4368c35fb..dac8e6280 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -13,310 +13,297 @@ class Transformer3DModel(nn.Module): - num_attention_heads: int = 16 - attention_head_dim: int = 88 - out_channels: int = 128 - num_layers: int = 1 - dropout: float = 0.0 - cross_attention_dim: Optional[int] = None - attention_bias: bool = False - activation_fn: str = "geglu" - num_embeds_ada_norm: Optional[int] = None - only_cross_attention: bool = False - double_self_attention: bool = False - upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" - standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' - norm_elementwise_affine: bool = True - norm_eps: float = 1e-5 - attention_type: str = "default" - caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True - qk_norm: Optional[str] = None - positional_embedding_type: str = "rope" - positional_embedding_theta: Optional[float] = None - positional_embedding_max_pos: Optional[List[int]] = None - timestep_scale_multiplier: Optional[float] = None - ffn_dim_mult: Optional[int] = 4 - output_scale: Optional[float] = None - attention_op: Optional[nn.Module] = None - dtype: jnp.dtype = jnp.float32 - weight_dtype: jnp.dtype = jnp.float32 - matmul_precision: str = "default" - sharding_mesh: Optional[jax.sharding.Mesh] = None - param_scan_axis: int = 0 - gradient_checkpointing: Optional[str] = None - - def setup(self): - assert self.out_channels is not None, "out channels must be specified in model config." 
- self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.patchify_proj = DenseGeneral( - self.inner_dim, - use_bias=True, - kernel_axes=(None, "embed"), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="patchify_proj", - ) - self.freq_cis_pre_computer = FreqsCisPrecomputer( - self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim - ) - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, - embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def scale_shift_table_init(key): - return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 - - self.scale_shift_table = self.param( - "scale_shift_table", # Trainable parameter name - nn.with_logical_partitioning( - scale_shift_table_init, ("ada", "embed")), - ) - self.norm_out = nn.LayerNorm( - epsilon=1e-6, use_scale=False, use_bias=False) - self.proj_out = DenseGeneral( - self.out_channels, - use_bias=True, - kernel_axes=("embed", None), - matmul_precision=self.matmul_precision, - weight_dtype=self.weight_dtype, - dtype=self.dtype, - name="proj_out", - ) - self.use_rope = self.positional_embedding_type == "rope" - if self.num_layers > 0: - RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( - BasicTransformerBlock - ) - - self.transformer_blocks = RepeatableLayer( - RemattedBasicTransformerBlock, - num_layers=self.num_layers, - module_init_kwargs=dict( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - dropout=self.dropout, - cross_attention_dim=self.cross_attention_dim, - activation_fn=self.activation_fn, - num_embeds_ada_norm=self.num_embeds_ada_norm, - attention_bias=self.attention_bias, - only_cross_attention=self.only_cross_attention, - double_self_attention=self.double_self_attention, - upcast_attention=self.upcast_attention, - adaptive_norm=self.adaptive_norm, - standardization_norm=self.standardization_norm, - norm_elementwise_affine=self.norm_elementwise_affine, - norm_eps=self.norm_eps, - attention_type=self.attention_type, - use_tpu_flash_attention=self.use_tpu_flash_attention, - qk_norm=self.qk_norm, - use_rope=self.use_rope, - ffn_dim_mult=self.ffn_dim_mult, - attention_op=self.attention_op, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - sharding_mesh=self.sharding_mesh, - name="CheckpointBasicTransformerBlock_0", - ), - pspec_name="layers", - param_scan_axis=self.param_scan_axis, - ) - - if self.caption_channels is not None: - self.caption_projection = CaptionProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - matmul_precision=self.matmul_precision, - ) - - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal( - 
split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal( - split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal( - split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal( - split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) - if eval_only: - return jax.eval_shape( - self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] - - def __call__( - self, - hidden_states, - indices_grid, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - segment_ids=None, - encoder_attention_segment_ids=None, - return_dict=True, - ): - hidden_states = self.patchify_proj(hidden_states) - freqs_cis = self.freq_cis_pre_computer(indices_grid) - - if self.timestep_scale_multiplier: - timestep = self.timestep_scale_multiplier * timestep - - batch_size = hidden_states.shape[0] - - timestep, embedded_timestep = self.adaln_single( - timestep, - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection( - encoder_hidden_states) - - if self.num_layers > 0: - hidden_states = self.transformer_blocks( - hidden_states, - freqs_cis, - segment_ids, - encoder_hidden_states, - encoder_attention_segment_ids, - timestep, - cross_attention_kwargs, - class_labels, - ) - # Output processing - - scale_shift_values = ( - self.scale_shift_table[jnp.newaxis, jnp.newaxis, - :, :] + embedded_timestep[:, :, jnp.newaxis] - ) - scale_shift_values = nn.with_logical_constraint( - scale_shift_values, ("activation_batch", "activation_length", - "activation_ada", "activation_embed") - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - if self.output_scale: - hidden_states = hidden_states / self.output_scale - - return hidden_states + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + # 'single_scale_shift' or 'single_scale' + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + # if True uses the TPU attention offload ('flash attention') + use_tpu_flash_attention: bool = True + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + 
timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): + + # 
bookkeeping, for convenient changes later + latents_shape = (batch_size, num_tokens, features) + fractional_cords_shape = (batch_size, 3, num_tokens) + prompt_embeds_shape = (batch_size, text_tokens, features) + noise_cond_shape = (batch_size, 1) + latents_dtype = jnp.bfloat16 + fractional_coords_dtype = jnp.bfloat16 + prompt_embeds_dtype = jnp.bfloat16 + noise_cond_dtype = jnp.bfloat16 + + # initialize to random + key, split_key = jax.random.split(key) + prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) + key, split_key = jax.random.split(key) + fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) + key, split_key = jax.random.split(key) + latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) + key, split_key = jax.random.split(key) + noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) + + key, split_key = jax.random.split(key) + if eval_only: + return jax.eval_shape( + self.init, + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + else: + return self.init( + rngs={"params": split_key}, + hidden_states=latents, + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + )["params"] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states def log_base(x: jax.Array, base: jax.Array) -> jax.Array: - """ - Computes log of x with defined base. + """ + Computes log of x with defined base. 
- Args: - x (jax.Array): log value - base (jax.Array): base of the log + Args: + x (jax.Array): log value + base (jax.Array): base of the log - Returns: - jax.Array: log(x)[base] - """ - return jnp.log(x) / jnp.log(base) + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) class FreqsCisPrecomputer(nn.Module): - """ - computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. - This is commonly used in rotary embeddings (RoPE) for transformers. - """ - - positional_embedding_max_pos: List[int] - positional_embedding_theta: float - inner_dim: int - - def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: - fractional_positions = jnp.stack( - [indices_grid[:, i] / self.positional_embedding_max_pos[i] - for i in range(3)], - axis=-1, - ) - return fractional_positions - - @nn.compact - def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: - source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 - dim = self.inner_dim - theta = self.positional_embedding_theta - - fractional_positions = self.get_fractional_positions(indices_grid) - - start = 1 - end = theta - indices = jnp.power( - theta, - jnp.linspace( - log_base(start, theta), - log_base(end, theta), - dim // 6, - dtype=dtype, - ), - ) - indices = indices.astype(dtype) - - indices = indices * jnp.pi / 2 - - freqs = (indices * (jnp.expand_dims(fractional_positions, - axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) - - cos_freq = jnp.cos(freqs).repeat(2, axis=-1) - sin_freq = jnp.sin(freqs).repeat(2, axis=-1) - - if dim % 6 != 0: - cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) - - cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) - return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. + """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + # We need full precision in the freqs_cis computation. 
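+    # Descriptive note on the computation below: each of the three coordinate axes
+    # in indices_grid contributes dim // 6 frequencies, geometrically spaced from
+    # 1 to theta and scaled by pi / 2. Fractional positions in [0, 1] are mapped to
+    # [-1, 1] before being multiplied in, and the resulting cos/sin tables are
+    # duplicated pairwise (and left-padded with cos=1 / sin=0 when dim % 6 != 0)
+    # so their final width matches inner_dim.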
+ dtype = jnp.float32 + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + # Flatten along axis 2 + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) From 7e098c586fad8874fa0d62912a42a11f159c9545 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Thu, 26 Jun 2025 22:56:13 +0000 Subject: [PATCH 04/25] format fixed --- src/maxdiffusion/configs/ltx_video.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 10 +-- .../ltx_video/transformers/attention.py | 2 +- .../ltx_video/transformers/transformer3d.py | 68 +++++++------------ .../ltx_video/xora_v1.2-13B-balanced-128.json | 3 +- 5 files changed, 33 insertions(+), 52 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 954922521..eb44d253b 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -62,4 +62,4 @@ cache_latents_text_encoder_outputs: True per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 -jit_initializers: True \ No newline at end of file +jit_initializers: True diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index d05203f5c..6efe564b2 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,11 +20,13 @@ import json from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os +import functools import jax.numpy as jnp from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, ) +from jax.sharding import Mesh def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): @@ -38,7 +40,7 @@ def run(config): key = jax.random.PRNGKey(0) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) @@ -49,12 +51,10 @@ def run(config): model_config = json.load(f) transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) + transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 key, split_key = jax.random.split(key) - - - weights_init_fn = functools.partial( + weights_init_fn = functools.partial( # noqa F841 transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True ) diff --git 
a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 5d12e7813..4812b89ba 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -438,7 +438,7 @@ def __call__( deterministic: bool = True, **cross_attention_kwargs, ) -> jnp.ndarray: - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 assert cross_attention_kwargs.get("scale", None) is None, "Not supported" input_axis_names = ("activation_batch", "activation_length", "activation_embed") diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index dac8e6280..cf599f26c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module): only_cross_attention: bool = False double_self_attention: bool = False upcast_attention: bool = False - # 'single_scale_shift' or 'single_scale' - adaptive_norm: str = "single_scale_shift" + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm' norm_elementwise_affine: bool = True norm_eps: float = 1e-5 attention_type: str = "default" caption_channels: int = None - # if True uses the TPU attention offload ('flash attention') - use_tpu_flash_attention: bool = True + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') qk_norm: Optional[str] = None positional_embedding_type: str = "rope" positional_embedding_theta: Optional[float] = None @@ -98,7 +96,7 @@ def scale_shift_table_init(key): self.transformer_blocks = RepeatableLayer( RemattedBasicTransformerBlock, num_layers=self.num_layers, - module_init_kwargs=dict( + module_init_kwargs=dict( # noqa C408 dim=self.inner_dim, num_attention_heads=self.num_attention_heads, attention_head_dim=self.attention_head_dim, @@ -139,46 +137,30 @@ def scale_shift_table_init(key): matmul_precision=self.matmul_precision, ) - def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True): - - # bookkeeping, for convenient changes later - latents_shape = (batch_size, num_tokens, features) - fractional_cords_shape = (batch_size, 3, num_tokens) - prompt_embeds_shape = (batch_size, text_tokens, features) - noise_cond_shape = (batch_size, 1) - latents_dtype = jnp.bfloat16 - fractional_coords_dtype = jnp.bfloat16 - prompt_embeds_dtype = jnp.bfloat16 - noise_cond_dtype = jnp.bfloat16 - - # initialize to random - key, split_key = jax.random.split(key) - prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype) - key, split_key = jax.random.split(key) - fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype) - key, split_key = jax.random.split(key) - latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype) - key, split_key = jax.random.split(key) - noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype) - - key, split_key = jax.random.split(key) + def init_weights(self, in_channels, key, caption_channels, eval_only=True): + 
example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + if eval_only: return jax.eval_shape( self.init, - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, + key, + **example_inputs, )["params"] else: - return self.init( - rngs={"params": split_key}, - hidden_states=latents, - indices_grid=fractional_coords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - )["params"] + return self.init(key, **example_inputs)["params"] def __call__( self, @@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: @nn.compact def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: source_dtype = indices_grid.dtype - # We need full precision in the freqs_cis computation. - dtype = jnp.float32 + dtype = jnp.float32 # We need full precision in the freqs_cis computation. dim = self.inner_dim theta = self.positional_embedding_theta @@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: indices = indices * jnp.pi / 2 freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) - # Flatten along axis 2 - freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 cos_freq = jnp.cos(freqs).repeat(2, axis=-1) sin_freq = jnp.sin(freqs).repeat(2, axis=-1) diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 02f13b15a..75b16b011 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -20,5 +20,6 @@ "positional_embedding_type": "rope", "positional_embedding_theta": 10000.0, "positional_embedding_max_pos": [20, 2048, 2048], - "timestep_scale_multiplier": 1000 + "timestep_scale_multiplier": 1000, + "in_channels": 128 } \ No newline at end of file From e18128c3f19db8a8cba15195397281addaf0558a Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 18:17:48 +0000 Subject: [PATCH 05/25] transformer step and test --- src/maxdiffusion/configs/ltx_video.yml | 24 ++- src/maxdiffusion/generate_ltx_video.py | 171 ++++++++++++--- src/maxdiffusion/max_utils.py | 21 +- .../ltx_video/xora_v1.2-13B-balanced-128.json | 1 + src/maxdiffusion/pyconfig.py | 17 ++ .../tests/ltx_transformer_step_test.py | 198 ++++++++++++++++++ .../tests/ltx_vid_transformer_test_ref_pred | Bin 0 -> 263834 bytes 7 files changed, 402 insertions(+), 30 deletions(-) create mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py create mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index eb44d253b..8f1ee8a7d 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,12 +22,25 @@ 
weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', ['data','fsdp']], @@ -40,13 +53,19 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 + ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 +ici_sequence_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 @@ -63,3 +82,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True +enable_single_replica_ckpt_restoring: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6efe564b2..bad791cee 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,23 +1,8 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-""" - from absl import app from typing import Sequence import jax import json +from flax.linen import partitioning as nn_partitioning from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import os import functools @@ -25,39 +10,171 @@ from maxdiffusion import pyconfig from maxdiffusion.max_utils import ( create_device_mesh, + setup_initial_state, + get_memory_allocations, ) -from jax.sharding import Mesh +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): +def validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids +): print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) print("latents.shape: ", latents.shape, latents.dtype) print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return noise_pred def run(config): - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 + mesh = Mesh(devices_array, config.mesh_axes) - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 base_dir = os.path.dirname(__file__) - # load in model config + ##load in model config config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa 
F841 + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) - key, split_key = jax.random.split(key) - weights_init_fn = functools.partial( # noqa F841 - transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + + # create dummy inputs: + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + validate_transformer_inputs( + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids + ) + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, ) + noise_pred = p_run_inference(states).block_until_ready() + print(noise_pred) # (4, 256, 128) + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..f3f5148b2 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if config.ici_fsdp_transpose_parallelism: + num_slices = 1 + # if 
config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -402,7 +417,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json index 75b16b011..c5b3c0ef9 100644 --- a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -1,4 +1,5 @@ { + "ckpt_path": "", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 128, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 67437ba0b..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -154,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py new file mode 100644 index 000000000..b0a266b70 --- /dev/null +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -0,0 +1,198 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +import torch +import jax +import numpy as np +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from jax.sharding import Mesh +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import functools +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import PartitionSpec as P +import orbax.checkpoint as ocp + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_ref_prediction(): + saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + predict_dict = torch.load(saved_prediction_path) + noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) + return noise_pred_pt + + +def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): + latents, state, noise_cond = args + noise_pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + indices_grid=fractional_cords, + encoder_hidden_states=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + return noise_pred, state, noise_cond + + +def run_inference( + states, + transformer, + config, + mesh, + latents, + fractional_cords, + prompt_embeds, + timestep, + segment_ids, + encoder_attention_segment_ids, +): + transformer_state = states["transformer"] + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + fractional_cords=fractional_cords, + prompt_embeds=prompt_embeds, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) + return latents + + +class LTXTransformerTest(unittest.TestCase): + + def test_one_step_transformer(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + noise_pred_pt = load_ref_prediction() + + # set up transformer + key = jax.random.PRNGKey(42) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + ignored_keys = [ + "_class_name", + "_diffusers_version", + "_name_or_path", + "causal_temporal_positioning", + "in_channels", + "ckpt_path", + ] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + transformer = Transformer3DModel( + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh + ) + weights_init_fn = functools.partial( + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + 
training=False, + ) + + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + states["transformer"] = transformer_state + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "latents": (batch_size, num_tokens, in_channels), + "fractional_coords": (batch_size, 3, num_tokens), + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(example_inputs["latents"], data_sharding) + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + config=config, + mesh=mesh, + latents=latents, + fractional_cords=fractional_coords, + prompt_embeds=prompt_embeds, + timestep=noise_cond, + segment_ids=segment_ids, + encoder_attention_segment_ids=encoder_attention_segment_ids, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + noise_pred = p_run_inference(states).block_until_ready() + noise_pred = torch.from_numpy(np.array(noise_pred)) + + torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred new file mode 100644 index 0000000000000000000000000000000000000000..0a9fe912036cf35e35d5d8127cff178e1b4a9399 GIT binary patch literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 znVp%q*CU%9^C53WiY#8e_0GlOr!3m`Sv(e$9>`*|`xYv->Y0`0WF?)QdHf?KGC5u( z5I@^^uQ%U#(Y8-uZ!q1R!0G4Gq+ay!?9BK_+U~vXsa4aomJ~T$3g3OSKi_`Qw$EW# zFx8zyaa0~G&CZN{oCG#HX|^6;jc$f>@pAiIRPB7P{kLd6zJ25Q=qK5e*~Mrn>g`-= z|K`T;qHjf?&EC%5I#SMdbp|{C2}iR_(VlEI+#6-tx1(#}<#0M$?))MAQWixoMK8p! 
[... remaining base85-encoded binary payload omitted ...]

literal 0
HcmV?d00001

From 1c554523fbeddc9d3f85f59ed2e23c36a32356be Mon Sep 17 00:00:00 2001
From: "serenagu@google.com"
Date: Mon, 30 Jun 2025 21:24:26 +0000
Subject: [PATCH 06/25] removed diffusers import

---
 .../models/ltx_video/transformers/activations.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py
index 4a78b48ea..8e7ffb321 100644
--- a/src/maxdiffusion/models/ltx_video/transformers/activations.py
+++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py
@@ -5,8 +5,6 @@
from flax import linen as nn from flax.linen.initializers import lecun_normal -from diffusers.utils.deprecation_utils import deprecate - from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer @@ -117,9 +115,9 @@ class GEGLU(nn.Module): @nn.compact def __call__(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) + # if len(args) > 0 or kwargs.get("scale", None) is not None: + # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + # deprecate("scale", "1.0.0", deprecation_message) proj = DenseGeneral( features=self.dim_out * 2, From fd4af91ddb23b73dd7fb58958f49fb24bd938db2 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Mon, 30 Jun 2025 21:30:45 +0000 Subject: [PATCH 07/25] fixed mesh --- src/maxdiffusion/max_utils.py | 63 +++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index f3f5148b2..c48b7da0f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True): devices = jax.devices() num_devices = len(devices) ##special case for ltx-video - if config.ici_fsdp_transpose_parallelism: + if "fsdp_transpose" in config.mesh_axes: num_slices = 1 # if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
@@ -628,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file From 5e17a62648b1d5d26198b3b1adea6cb89e0eb904 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 01:05:14 +0000 Subject: [PATCH 08/25] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index b0a266b70..d0f6c2e1d 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,7 @@ def load_ref_prediction(): - saved_prediction_path = "ltx_vid_transformer_test_ref_pred" + saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From fc60b27b30e2988a4e4fad2d0252a0f044ded270 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 03:44:05 +0000 Subject: [PATCH 09/25] changed path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d0f6c2e1d..d40c932ba 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -39,7 +39,8 @@ def load_ref_prediction(): - saved_prediction_path = "../ltx_vid_transformer_test_ref_pred" + base_dir = os.path.dirname(__file__) + saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") predict_dict = torch.load(saved_prediction_path) noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) return noise_pred_pt From 3243535216e0348c36ddd7c3e90316c66343026d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:33:16 +0000 Subject: [PATCH 10/25] changed config path --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index d40c932ba..43b15ba72 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -103,7 +103,9 @@ def test_one_step_transformer(self): key = jax.random.PRNGKey(42) devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - config_path = "../models/ltx_video/xora_v1.2-13B-balanced-128.json" + base_dir = os.path.dirname(__file__) + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From e873a17c795037318b46222c86137bb7899bb43d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Tue, 1 Jul 2025 04:44:51 +0000 Subject: [PATCH 11/25] ruff check --- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 43b15ba72..61c6909c0 100644 --- 
a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -105,7 +105,7 @@ def test_one_step_transformer(self): mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - + with open(config_path, "r") as f: model_config = json.load(f) relative_ckpt_path = model_config["ckpt_path"] From d06dee301231259fa83f3e7849d4af0f8655d8bf Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 17:55:48 +0000 Subject: [PATCH 12/25] changed back pyconfig --- src/maxdiffusion/max_utils.py | 78 +---------------------------------- src/maxdiffusion/pyconfig.py | 41 ++++++++++-------- 2 files changed, 24 insertions(+), 95 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +288,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. @@ -474,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..f4e1900aa 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
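+              # Distilled CausVid / FusionX checkpoints are configured above to run with
+              # guidance_scale=1.0 and are expected to use few inference steps; this check
+              # only logs a warning and does not abort the run.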
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): @@ -221,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) + r = range(config.steps) \ No newline at end of file From aa7befd137cbb9bd28db098d23af7ab75de37b5d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 2 Jul 2025 21:38:05 +0000 Subject: [PATCH 13/25] changed sharding back --- src/maxdiffusion/configs/ltx_video.yml | 8 +++++--- src/maxdiffusion/max_utils.py | 2 +- .../models/ltx_video/transformers/attention.py | 13 ++++++++++--- src/maxdiffusion/pyconfig.py | 2 +- src/maxdiffusion/tests/ltx_transformer_step_test.py | 2 +- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 8f1ee8a7d..d29707537 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -40,20 +40,22 @@ output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] +mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], + ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], + ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] +data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..fab895f97 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -609,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 4812b89ba..e9d9d932d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -622,14 +622,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") # Computation of the spec based on the logical constraints can be found in 
logical_axes_to_spec.py. + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) qkvo_sharding_spec = jax.sharding.PartitionSpec( - ("data", "fsdp", "fsdp_transpose", "expert"), - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + None, + ("data", "fsdp", "tensor"), None, None, ) # Based on: ("activation_kv_batch", "activation_length") - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkv_segment_ids_spec = jax.sharding.PartitionSpec("fsdp", None) + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) wrapped_flash_attention = shard_map( partial_flash_attention, mesh=sharding_mesh, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index f4e1900aa..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -226,4 +226,4 @@ def initialize(argv, **kwargs): if __name__ == "__main__": initialize(sys.argv) print(config.steps) - r = range(config.steps) \ No newline at end of file + r = range(config.steps) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 61c6909c0..9a816d6e5 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -104,7 +104,7 @@ def test_one_step_transformer(self): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") with open(config_path, "r") as f: model_config = json.load(f) From d9a35020c13df74bd91ac3e74a111d8a1e82673d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Sat, 5 Jul 2025 23:07:45 +0000 Subject: [PATCH 14/25] removed testing for now --- src/maxdiffusion/max_utils.py | 61 +----- .../tests/ltx_transformer_step_test.py | 201 ------------------ .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 3 files changed, 2 insertions(+), 260 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_transformer_step_test.py delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index c48b7da0f..d4a80a347 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh - + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") - - - - - - - - - - - - - - - - - - return mesh - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
@@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 61c6909c0..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = 
Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config = json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 
0 HcmV?d00001 literal 263834 [base85-encoded binary payload of src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred omitted] From a1ad421eda51fa42d34bc4667ff32f107d9d36e8 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat,
5 Jul 2025 16:15:25 -0700 Subject: [PATCH 15/25] Update pyconfig.py --- src/maxdiffusion/pyconfig.py | 39 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index edcf96164..af6493ea2 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,7 +25,6 @@ import yaml from . import max_logging from . import max_utils -from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -42,6 +41,21 @@ def string_to_bool(s: str) -> bool: config = None +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_sequence_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + return raw_keys + + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -103,7 +117,6 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) - _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -112,26 +125,6 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict - @staticmethod - def wan_init(raw_keys): - if "wan_transformer_pretrained_model_name_or_path" in raw_keys: - transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] - if transformer_pretrained_model_name_or_path == "": - raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] - elif ( - transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH - or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH - ): - # Set correct parameters for CausVid in case of user error. - raw_keys["guidance_scale"] = 1.0 - num_inference_steps = raw_keys["num_inference_steps"] - if num_inference_steps > 10: - max_logging.log( - f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
- ) - else: - raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") - @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -176,6 +169,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "ici_fsdp_transpose_parallelism" in raw_keys: + raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): From 615174f94f66ebd5e19c98d1cac2b0bf242b70de Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:15:57 -0700 Subject: [PATCH 16/25] Update max_utils.py --- src/maxdiffusion/max_utils.py | 78 ++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..51de312ca 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) + ##special case for ltx-video + if "fsdp_transpose" in config.mesh_axes: + num_slices = 1 + # if config.inference_benchmark_test else config.num_slices + num_devices_per_slice = num_devices // num_slices + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -288,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True): if logging: max_logging.log(f"Decided on mesh: {mesh}") + + + + + + + + + + + + + + + + + + return mesh + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState): """Unboxes the flax.LogicallyPartitioned pieces in a train state. 
@@ -402,7 +474,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( From 7469c62c97f9f711114111490f9220a433b18d8a Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:26 -0700 Subject: [PATCH 17/25] Update ltx_video.yml --- src/maxdiffusion/configs/ltx_video.yml | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index d29707537..87d0e9bea 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -22,47 +22,32 @@ weights_dtype: 'bfloat16' activations_dtype: 'bfloat16' -run_name: '' -output_dir: 'ltx-video-output' -save_config_to_gcs: False - -#hardware -hardware: 'tpu' -skip_jax_distributed_system: False - -jax_cache_dir: '' -weights_dtype: 'bfloat16' -activations_dtype: 'bfloat16' - - run_name: '' output_dir: 'ltx-video-output' save_config_to_gcs: False #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], - ['norm', 'fsdp'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 - ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 + ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_tensor_transpose_parallelism: 1 @@ -84,4 +69,4 @@ per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 jit_initializers: True -enable_single_replica_ckpt_restoring: False \ No newline at end of file +enable_single_replica_ckpt_restoring: False From 6de4424170129d21fb4c11b38a62cccfc8e9a0d2 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:16:51 -0700 Subject: [PATCH 18/25] Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred --- .../tests/ltx_vid_transformer_test_ref_pred | Bin 263834 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred diff --git a/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred b/src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred deleted file mode 100644 index 0a9fe912036cf35e35d5d8127cff178e1b4a9399..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 263834 zcmeI*O>7)z835q1n;(?UiS& zPwV}ecV2%pws&W~?|tUWJNklPa4>ja*%RyxT8n3srL|hpJf5`D(Su8sv@*PUt~NU} z5S+TZT_BzEJMD-1x+I7Z>ZbTC$i{>np9Hx#m)m=k(?3 zVmj0q`ogis&b0B#V~tO>hUd!zgJV~&ow%OIE!zCHU9m4X)ZG=s&(xdCm2}~J(ro41 
[remainder of the base85-encoded binary payload of src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred omitted] From 18ec247aa0b181f1ded5642093027d1cce109b3e Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Sat, 5 Jul 2025 16:17:01 -0700 Subject: [PATCH 19/25] Delete src/maxdiffusion/tests/ltx_transformer_step_test.py --- .../tests/ltx_transformer_step_test.py | 201 ------------------ 1 file changed, 201 deletions(-) delete mode 100644
src/maxdiffusion/tests/ltx_transformer_step_test.py diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py deleted file mode 100644 index 9a816d6e5..000000000 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ /dev/null @@ -1,201 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import os -import torch -import jax -import numpy as np -import jax.numpy as jnp -import unittest -from absl.testing import absltest -from jax.sharding import Mesh -import json -from flax.linen import partitioning as nn_partitioning -from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel -import functools -from maxdiffusion import pyconfig -from maxdiffusion.max_utils import ( - create_device_mesh, - setup_initial_state, - get_memory_allocations, -) -from jax.sharding import PartitionSpec as P -import orbax.checkpoint as ocp - -THIS_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def load_ref_prediction(): - base_dir = os.path.dirname(__file__) - saved_prediction_path = os.path.join(base_dir, "ltx_vid_transformer_test_ref_pred") - predict_dict = torch.load(saved_prediction_path) - noise_pred_pt = predict_dict["noise_pred"].to(torch.float32) - return noise_pred_pt - - -def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids): - latents, state, noise_cond = args - noise_pred = transformer.apply( - {"params": state.params}, - hidden_states=latents, - indices_grid=fractional_cords, - encoder_hidden_states=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - return noise_pred, state, noise_cond - - -def run_inference( - states, - transformer, - config, - mesh, - latents, - fractional_cords, - prompt_embeds, - timestep, - segment_ids, - encoder_attention_segment_ids, -): - transformer_state = states["transformer"] - loop_body_p = functools.partial( - loop_body, - transformer=transformer, - fractional_cords=fractional_cords, - prompt_embeds=prompt_embeds, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) - return latents - - -class LTXTransformerTest(unittest.TestCase): - - def test_one_step_transformer(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"), - ], - unittest=True, - ) - config = pyconfig.config - noise_pred_pt = load_ref_prediction() - - # set up transformer - key = jax.random.PRNGKey(42) - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") - - with open(config_path, "r") as f: - model_config 
= json.load(f) - relative_ckpt_path = model_config["ckpt_path"] - ignored_keys = [ - "_class_name", - "_diffusers_version", - "_name_or_path", - "causal_temporal_positioning", - "in_channels", - "ckpt_path", - ] - in_channels = model_config["in_channels"] - for name in ignored_keys: - if name in model_config: - del model_config[name] - - transformer = Transformer3DModel( - **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh - ) - weights_init_fn = functools.partial( - transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True - ) - - absolute_ckpt_path = os.path.abspath(relative_ckpt_path) - - checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - checkpoint_manager=checkpoint_manager, - checkpoint_item=" ", - model_params=None, - training=False, - ) - - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - states["transformer"] = transformer_state - example_inputs = {} - batch_size, num_tokens = 4, 256 - input_shapes = { - "latents": (batch_size, num_tokens, in_channels), - "fractional_coords": (batch_size, 3, num_tokens), - "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), - "timestep": (batch_size, 256), - "segment_ids": (batch_size, 256), - "encoder_attention_segment_ids": (batch_size, 128), - } - for name, shape in input_shapes.items(): - example_inputs[name] = jnp.ones( - shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool - ) - - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(example_inputs["latents"], data_sharding) - prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) - fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) - noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) - segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) - encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - config=config, - mesh=mesh, - latents=latents, - fractional_cords=fractional_coords, - prompt_embeds=prompt_embeds, - timestep=noise_cond, - segment_ids=segment_ids, - encoder_attention_segment_ids=encoder_attention_segment_ids, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - noise_pred = p_run_inference(states).block_until_ready() - noise_pred = torch.from_numpy(np.array(noise_pred)) - - torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20) - - -if __name__ == "__main__": - absltest.main() From 546ecab301c4e321d78bb2464a0df936c55beecb Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 00:20:43 +0000 Subject: [PATCH 20/25] ruff fixed --- src/maxdiffusion/generate_ltx_video.py | 8 +--- src/maxdiffusion/max_utils.py | 23 +---------- src/maxdiffusion/pyconfig.py | 39 +++++++++++-------- .../tests/ltx_transformer_step_test.py | 2 +- 4 files changed, 27 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/generate_ltx_video.py 
b/src/maxdiffusion/generate_ltx_video.py index 371d309e3..fa495ba1a 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -64,10 +64,6 @@ def run_inference( segment_ids=segment_ids, encoder_attention_segment_ids=encoder_attention_segment_ids, ) - prof = profiler.Profiler(config) - prof.activate(optional_postfix="transformer step") - prof.deactivate() - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) @@ -176,8 +172,8 @@ def run(config): in_shardings=(state_shardings,), out_shardings=None, ) - with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): - noise_pred = p_run_inference(states).block_until_ready() + + noise_pred = p_run_inference(states).block_until_ready() print(noise_pred) # (4, 256, 128) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index d4a80a347..9c88a2ac3 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -257,21 +257,6 @@ def create_device_mesh(config, devices=None, logging=True): if devices is None: devices = jax.devices() num_devices = len(devices) - ##special case for ltx-video - if "fsdp_transpose" in config.mesh_axes: - num_slices = 1 - # if config.inference_benchmark_test else config.num_slices - num_devices_per_slice = num_devices // num_slices - # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") - mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) - max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") - - return mesh - try: num_slices = 1 + max([d.slice_index for d in devices]) except: @@ -417,11 +402,7 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - ###!Edited - if checkpoint_item == " ": - state = state - else: - state = state[checkpoint_item] + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -628,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() + jax.distributed.initialize() \ No newline at end of file diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index af6493ea2..edcf96164 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -25,6 +25,7 @@ import yaml from . import max_logging from . 
import max_utils +from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH def string_to_bool(s: str) -> bool: @@ -41,21 +42,6 @@ def string_to_bool(s: str) -> bool: config = None -def create_parallelisms_list(raw_keys): - ici_parallelism = [ - raw_keys["ici_data_parallelism"], - raw_keys["ici_fsdp_parallelism"], - raw_keys["ici_fsdp_transpose_parallelism"], - raw_keys["ici_sequence_parallelism"], - raw_keys["ici_tensor_parallelism"], - raw_keys["ici_tensor_transpose_parallelism"], - raw_keys["ici_expert_parallelism"], - raw_keys["ici_sequence_parallelism"], - ] - raw_keys["ici_parallelism"] = ici_parallelism - return raw_keys - - def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") @@ -117,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs): jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"]) _HyperParameters.user_init(raw_keys) + _HyperParameters.wan_init(raw_keys) self.keys = raw_keys for k in sorted(raw_keys.keys()): max_logging.log(f"Config param {k}: {raw_keys[k]}") @@ -125,6 +112,26 @@ def _load_kwargs(self, argv: list[str]): args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict + @staticmethod + def wan_init(raw_keys): + if "wan_transformer_pretrained_model_name_or_path" in raw_keys: + transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] + if transformer_pretrained_model_name_or_path == "": + raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + elif ( + transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH + or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH + ): + # Set correct parameters for CausVid in case of user error. + raw_keys["guidance_scale"] = 1.0 + num_inference_steps = raw_keys["num_inference_steps"] + if num_inference_steps > 10: + max_logging.log( + f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps." 
+ ) + else: + raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -169,8 +176,6 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "ici_fsdp_transpose_parallelism" in raw_keys: - raw_keys = create_parallelisms_list(raw_keys) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 2555b330c..9398c9156 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -191,7 +191,7 @@ def test_one_step_transformer(self): in_shardings=(state_shardings,), out_shardings=None, ) - + noise_pred = p_run_inference(states).block_until_ready() noise_pred = torch.from_numpy(np.array(noise_pred)) From 12a247fe9fa71af862a69644ea4275ec7c7d791d Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 17:39:09 +0000 Subject: [PATCH 21/25] added header --- .../models/ltx_video/transformers/activations.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/adaln.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/attention.py | 16 ++++++++++++++++ .../ltx_video/transformers/caption_projection.py | 16 ++++++++++++++++ .../ltx_video/transformers/transformer3d.py | 16 ++++++++++++++++ 5 files changed, 80 insertions(+) diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py index 8e7ffb321..4ae1d9a00 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/activations.py +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 4bc27e8bc..e9b287649 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Dict, Optional, Tuple import jax diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py index 8a7541263..9faab1ded 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from functools import partial import math from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py index f2b1af101..d8240989c 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from flax import linen as nn import jax.numpy as jnp diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index cf599f26c..d6c7cf4c4 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import List, Optional, Tuple import jax From 1062c72f3f4a9429404154d51d3b9081e87fe762 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:11:17 +0000 Subject: [PATCH 22/25] license headers --- .github/workflows/UnitTests.yml | 2 +- src/maxdiffusion/generate_ltx_video.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/__init__.py | 15 +++++++++++++++ .../models/ltx_video/gradient_checkpoint.py | 16 ++++++++++++++++ src/maxdiffusion/models/ltx_video/linear.py | 16 ++++++++++++++++ .../models/ltx_video/repeatable_layer.py | 16 ++++++++++++++++ .../models/ltx_video/transformers/__init__.py | 15 +++++++++++++++ 7 files changed, 95 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..c1fa771d1 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index fa495ba1a..2dec16fa6 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,3 +1,19 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + from absl import app from typing import Sequence import jax diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py index ef8c530ba..ee7221652 100644 --- a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from enum import Enum, auto from typing import Optional diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py index 31b21cdd9..3503ab3b4 100644 --- a/src/maxdiffusion/models/ltx_video/linear.py +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from typing import Union, Iterable, Tuple, Optional, Callable import numpy as np diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py index aaed41048..7e9cc80c4 100644 --- a/src/maxdiffusion/models/ltx_video/repeatable_layer.py +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -1,3 +1,19 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main from dataclasses import field from typing import Any, Callable, Dict, List, Tuple, Optional diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index e69de29bb..7e4185f36 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ From 535c75eea64044e3b87fce1b4ebcee52b80600a1 Mon Sep 17 00:00:00 2001 From: "serenagu@google.com" Date: Wed, 9 Jul 2025 22:34:49 +0000 Subject: [PATCH 23/25] exclude test --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c1fa771d1..05f332fb7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=maxdiffusion/src/maxdiffusion/tests/ltx_transformer_step_test.py + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: From 64b82c943326b179972a1f25488079c2432ad78c Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:35:46 -0700 Subject: [PATCH 24/25] Update checkpointing_utils.py --- src/maxdiffusion/checkpointing/checkpointing_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..27046a0b6 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,11 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) def map_to_pspec(data): pspec = data.sharding.spec @@ -248,3 +251,4 @@ def map_to_pspec(data): except: max_logging.log(f"could not load {checkpoint_item} from orbax") return None 
+ From 103db8f54f0a7d61cabc27728ee2b79eef2b0b33 Mon Sep 17 00:00:00 2001 From: Serenagu525 <41308432+Serenagu525@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:36:46 -0700 Subject: [PATCH 25/25] Update max_utils.py --- src/maxdiffusion/max_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 9c88a2ac3..e645ecec1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,10 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( @@ -609,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize()
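
For context on the sharding changes above (PATCH 13 narrows mesh_axes to ['data', 'fsdp', 'tensor'] and PATCH 17 widens them again, with attention.py carrying the matching qkvo_sharding_spec), here is a minimal standalone JAX sketch of what a spec like PartitionSpec(None, ('data', 'fsdp', 'tensor'), None, None) does to a (batch, heads, length, head_dim) tensor. The mesh shape, tensor sizes, and variable names are illustrative assumptions, not code from the patches.

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Put every available device on the fsdp axis; data and tensor stay at size 1.
devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Same shape of spec as qkvo_sharding_spec in attention.py: only the heads
# axis (dim 1) is split, across all three mesh axes; batch, length and
# head_dim stay replicated.
qkvo_spec = P(None, ("data", "fsdp", "tensor"), None, None)

batch, heads, seq_len, head_dim = 2, 4 * jax.device_count(), 256, 128
q = np.zeros((batch, heads, seq_len, head_dim), dtype=np.float32)
q_sharded = jax.device_put(q, NamedSharding(mesh, qkvo_spec))
print(q_sharded.sharding)  # heads dimension is partitioned over data * fsdp * tensor devices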
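
PATCH 24 and PATCH 25 also reintroduce a sentinel: when checkpoint_item is a single space (" "), setup_initial_state and load_state_if_possible restore the whole Orbax checkpoint (StandardRestore) instead of pulling one named item out of a Composite restore and indexing state[checkpoint_item]. A rough standalone sketch of the two restore shapes, assuming a toy pytree and a temporary directory in place of the real transformer state and checkpoint path:

import numpy as np
import orbax.checkpoint as ocp

root = ocp.test_utils.erase_and_create_empty("/tmp/orbax_restore_demo")
state = {"params": {"w": np.ones((2, 2), dtype=np.float32)}}

# Whole-checkpoint restore: the branch taken when checkpoint_item == " ".
mngr = ocp.CheckpointManager(root / "whole")
mngr.save(0, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()
whole = mngr.restore(0, args=ocp.args.StandardRestore(state))
mngr.close()

# Named-item restore: the branch taken for a real item name such as
# "transformer"; the caller then indexes restored["transformer"].
mngr = ocp.CheckpointManager(root / "items")
mngr.save(0, args=ocp.args.Composite(transformer=ocp.args.PyTreeSave(state)))
mngr.wait_until_finished()
restored = mngr.restore(0, args=ocp.args.Composite(transformer=ocp.args.PyTreeRestore(state)))
mngr.close()

print(whole["params"]["w"], restored["transformer"]["params"]["w"])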