diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index a4f3c752a..1747be8c8 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -58,6 +58,16 @@ data_sharding: ['data', 'fsdp', 'context', 'tensor'] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 + +flash_block_sizes: { + block_q: 2048, + block_kv: 2048, + block_kv_compute: 1024, + block_q_dkv: 2048, + block_kv_dkv: 2048, + block_kv_dkv_compute: 2048, + use_fused_bwd_kernel: True, +} dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index b95ffd610..2ccce7488 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -23,6 +23,7 @@ Array = common_types.Array Mesh = common_types.Mesh DType = common_types.DType +BlockSizes = common_types.BlockSizes def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array: @@ -193,9 +194,7 @@ def prepare_video_coords( # pixel_coords[:, 0, ...] selects Frame dimension. # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W) frame_coords = pixel_coords[:, 0, ...] - frame_coords = jnp.clip( - frame_coords + self.causal_offset - self.scale_factors[0], min=0 - ) + frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0) pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps) return pixel_coords @@ -212,16 +211,12 @@ def prepare_audio_coords( # 2. Start timestamps audio_scale_factor = self.scale_factors[0] grid_start_mel = grid_f * audio_scale_factor - grid_start_mel = jnp.clip( - grid_start_mel + self.causal_offset - audio_scale_factor, min=0 - ) + grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0) grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate # 3. End timestamps grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor - grid_end_mel = jnp.clip( - grid_end_mel + self.causal_offset - audio_scale_factor, min=0 - ) + grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0) grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate # Stack [num_patches, 2] @@ -351,6 +346,7 @@ def __init__( dtype: DType = jnp.float32, attention_kernel: str = "flash", rope_type: str = "interleaved", + flash_block_sizes: BlockSizes = None, ): self.heads = heads self.rope_type = rope_type @@ -437,6 +433,7 @@ def __init__( dtype=dtype, axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV), axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV), + flash_block_sizes=flash_block_sizes, ) def __call__( diff --git a/src/maxdiffusion/models/ltx2/transformer_ltx2.py b/src/maxdiffusion/models/ltx2/transformer_ltx2.py index 6a7c2f867..767e58235 100644 --- a/src/maxdiffusion/models/ltx2/transformer_ltx2.py +++ b/src/maxdiffusion/models/ltx2/transformer_ltx2.py @@ -24,6 +24,7 @@ from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.common_types import BlockSizes class LTX2AdaLayerNormSingle(nnx.Module): @@ -105,6 +106,7 @@ def __init__( names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], attention_kernel: str = "flash", + flash_block_sizes: BlockSizes = None, ): self.dim = dim self.norm_eps = norm_eps @@ -134,6 +136,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) self.audio_norm1 = nnx.RMSNorm( @@ -158,6 +161,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) # 2. Prompt Cross-Attention @@ -184,6 +188,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) self.audio_norm2 = nnx.RMSNorm( @@ -209,6 +214,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -235,6 +241,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) self.video_to_audio_norm = nnx.RMSNorm( @@ -260,6 +267,7 @@ def __init__( mesh=mesh, attention_kernel=self.attention_kernel, rope_type=rope_type, + flash_block_sizes=flash_block_sizes, ) # 4. Feed Forward @@ -553,6 +561,7 @@ def __init__( scan_layers: bool = True, attention_kernel: str = "flash", qk_norm: str = "rms_norm_across_heads", + flash_block_sizes: BlockSizes = None, **kwargs, ): self.in_channels = in_channels @@ -791,6 +800,7 @@ def init_block(rngs): names_which_can_be_saved=self.names_which_can_be_saved, names_which_can_be_offloaded=self.names_which_can_be_offloaded, attention_kernel=self.attention_kernel, + flash_block_sizes=flash_block_sizes, ) if self.scan_layers: @@ -822,6 +832,7 @@ def init_block(rngs): names_which_can_be_saved=self.names_which_can_be_saved, names_which_can_be_offloaded=self.names_which_can_be_offloaded, attention_kernel=self.attention_kernel, + flash_block_sizes=flash_block_sizes, ) blocks.append(block) self.transformer_blocks = nnx.List(blocks) diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 3369f2031..0db6c398a 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -47,7 +47,7 @@ from ...pyconfig import HyperParameters from ... import max_logging from ... import max_utils -from ...max_utils import get_precision, device_put_replicated +from ...max_utils import get_precision, device_put_replicated, get_flash_block_sizes from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs @@ -124,6 +124,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): ltx2_config["weights_dtype"] = config.weights_dtype ltx2_config["attention_kernel"] = config.attention ltx2_config["precision"] = get_precision(config) + ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config) ltx2_config["remat_policy"] = config.remat_policy ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded