# Design Doc: Centralized and Configuration-Driven Sharding Strategy

## Objective
To centralize the sharding logic in MaxDiffusion, enabling hardware-specific optimizations (e.g., for TPU v6e vs. v7x) without hardcoding checks in model layers or polluting constructors with sharding parameters.

## Background
Currently, sharding specifications are often hardcoded within model layers or determined by ad-hoc hardware checks (e.g., checking `jax.devices()[0].device_kind`). This makes the code:
- **Hard to maintain and extend**: Adding support for new hardware requires modifying multiple files.
- **Difficult to test**: It is hard to exercise different sharding strategies on the same hardware for debugging or benchmarking.
- **Cluttered**: Model definition code is mixed with hardware-specific execution policies.

The `prisha/ltx2_opt` branch makes an initial attempt to address this by abstracting TPU type detection, but the sharding specs themselves are still hardcoded per detected hardware in [attention_ltx2.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/models/ltx2/attention_ltx2.py).

## Proposed Design

We propose a design that combines **Discrete Logical Rulesets** with **Explicit Parameter Passing at the Top Level** to achieve a clean separation of concerns while adhering to JAX and Flax NNX best practices.

### 1. Configuration
We will add a `sharding` section to the YAML configuration files, allowing independent overrides for different model components (e.g., Transformer, VAE).

Example in `ltx2_video.yml`:
```yaml
sharding:
  transformer: 'ironwood'
  vae: 'default'
  text_encoder: 'default'
```

#### Auto-Detection & Backward Compatibility
To improve usability and ensure backward compatibility:
- **Auto-Detection**: Specifying the sharding strategy is **optional**. If it is omitted (or a legacy config file lacks the `sharding` block entirely), `pyconfig.py` will auto-detect the TPU hardware generation at startup and set the strategy to the optimal default for that chip (e.g., `'ironwood'` for v7x); see the sketch after this list.
- **Logging**: The resolved strategy will be explicitly logged for transparency.
- **Overrides**: Users can always override the auto-detection by explicitly setting the strategy in the YAML file or via the CLI.
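
A minimal sketch of this resolution logic, assuming a hypothetical `resolve_sharding_strategy` helper in `pyconfig.py` (the `device_kind` strings and their mapping are illustrative assumptions):

```python
import logging

import jax

# Assumed device_kind -> default strategy mapping; the exact strings each
# TPU generation reports are an assumption in this sketch.
_DEFAULT_STRATEGY_BY_CHIP = {
    "TPU v6e": "trillium",
    "TPU v7x": "ironwood",
}


def resolve_sharding_strategy(raw_config: dict, component: str) -> str:
  """Returns the configured strategy for a component, else a hardware default."""
  configured = raw_config.get("sharding", {}).get(component)
  if configured:
    return configured
  device_kind = jax.devices()[0].device_kind
  strategy = _DEFAULT_STRATEGY_BY_CHIP.get(device_kind, "trillium")
  logging.info(
      "sharding.%s not set; auto-detected %r, using strategy %r",
      component, device_kind, strategy,
  )
  return strategy
```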

### 2. Discrete Logical Rulesets (Model-Specific File)
To keep the code simple and avoid file clutter, we organize the sharding specs into a single file per model, located in the model's directory. This keeps the sharding logic close to the model code, where model developers will naturally look for it.

For LTX2, this file will be `src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py`.

This file will contain the discrete specs, the registry, and the factory function:

```python
from dataclasses import dataclass
from typing import Any, Optional


# --- Discrete Specs ---
@dataclass
class LTX2DiTShardingSpecs:
  """Sharding specs for the LTX2 Diffusion Transformer."""

  qkv_kernel: tuple
  out_kernel: tuple
  out_bias: tuple
  norm_scale: tuple = ("norm",)
  embed_bias: tuple = ("embed",)


@dataclass
class TextEncoderShardingSpecs:
  """Specs for the Text Encoder execution."""

  use_batched_text_encoder: bool = False
  text_encoder_kernel: Optional[tuple] = None


@dataclass
class VAEShardingSpecs:
  """Sharding specs for the VAE."""

  vae_conv_kernel: Optional[tuple] = None


# --- Unified Registry for LTX2 ---
STRATEGIES = {
    "ironwood": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=(None, "heads"),
            out_kernel=("heads", None),
            out_bias=(None,),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=True,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
    },
    "trillium": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=("embed", "heads"),
            out_kernel=("heads", "embed"),
            out_bias=("embed",),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=False,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=(None, None, None, None)),
    },
}


def get_sharding_specs(strategy_name: str, component_name: str) -> Any:
  """Unified factory to get specs for any component."""
  # Unknown names (including 'default') fall back to the 'trillium' profile.
  hardware_profile = STRATEGIES.get(strategy_name, STRATEGIES["trillium"])
  specs = hardware_profile.get(component_name)
  if specs is None:
    raise ValueError(f"Component {component_name!r} not found in strategy {strategy_name!r}")
  return specs
```
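
For illustration, the factory simply returns the dataclass instances registered above:

```python
# Values correspond to the 'ironwood' entry in STRATEGIES.
dit_specs = get_sharding_specs("ironwood", "ltx2_dit")
assert dit_specs.qkv_kernel == (None, "heads")
assert dit_specs.norm_scale == ("norm",)
```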

### 3. Application (Unpacking at the Top Level)

To avoid coupling low-level layers to model-specific strategy objects, the top-level model (e.g., `LTX2VideoTransformer3DModel`) will accept the specs object, but will **unpack** it and pass only the specific tuples or `PartitionSpec`s down to the leaf nodes (like `LTX2Attention`).

#### In the Pipeline
The pipeline file (e.g., `ltx2_pipeline.py`) reads the strategy name from the config, retrieves the specific specs object for each component, and passes it to the respective top-level model.

```python
# 1. Read component-specific strategy names from config
sharding_config = getattr(self.config, "sharding", {})
transformer_strategy = sharding_config.get("transformer", "default")
te_strategy = sharding_config.get("text_encoder", "default")

# 2. Get the specific specs for components
dit_specs = get_sharding_specs(transformer_strategy, "ltx2_dit")
te_specs = get_sharding_specs(te_strategy, "text_encoder")

# 3. Use for pipeline execution choices
if te_specs.use_batched_text_encoder:
  # ...

# 4. Pass to the top-level model
self.transformer = LTX2VideoTransformer3DModel(
    # ...
    sharding_specs=dit_specs,
)
```

#### In Model Layers
The top-level model receives the specs object and unpacks it for its children.

Example in `LTX2VideoTransformer3DModel`:
```python
class LTX2VideoTransformer3DModel(nnx.Module):

  def __init__(self, ..., sharding_specs: LTX2DiTShardingSpecs):
    # Unpack and pass specific tuples to blocks
    self.block = LTX2VideoTransformerBlock(
        ...,
        qkv_sharding_spec=sharding_specs.qkv_kernel,
        out_sharding_spec=sharding_specs.out_kernel,
        out_bias_sharding_spec=sharding_specs.out_bias,
    )
```

Example in `LTX2Attention` (leaf node):
```python
class LTX2Attention(nnx.Module):

  def __init__(
      self,
      ...,
      qkv_sharding_spec: tuple,
      out_sharding_spec: tuple,
      out_bias_sharding_spec: tuple,
  ):
    # Use the specific tuples directly, completely agnostic to the parent strategy
    self.qkv_sharding_spec = qkv_sharding_spec
    # ...
```
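
Inside the leaf layer, the spec can then be attached to the parameter as Flax NNX sharding metadata. A minimal sketch (the `to_qkv` attribute and the `in_features`, `inner_dim`, and `rngs` names are illustrative assumptions):

```python
# Sketch: attach the logical spec to the QKV kernel at creation time.
# nnx.with_partitioning wraps the initializer so the created parameter
# carries the logical axis names, which logical_axis_rules later resolve
# to physical mesh axes (see the next section).
self.to_qkv = nnx.Linear(
    in_features,
    3 * inner_dim,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), qkv_sharding_spec
    ),
    rngs=rngs,
)
```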
| 174 | + |
| 175 | +### 4. Logical-to-Physical Mesh Mapping |
| 176 | +Logical axis names like `"heads"` and `"embed"` must be bound to a physical JAX Mesh. |
| 177 | + |
| 178 | +In MaxDiffusion, this mapping is handled at the top level via `logical_axis_rules` (typically defined in the YAML config file). These rules map logical axis names to physical mesh axes (e.g., `"data"`, `"model"`, `"fsdp"`). |
| 179 | + |
| 180 | +Different TPU topologies (v6e vs v7x) have different optimal physical mesh dimensions. We handle this by selecting the appropriate config file via the CLI, or by overriding the `logical_axis_rules` directly from the CLI. |
| 181 | + |
| 182 | +Example of overriding `logical_axis_rules` directly via CLI: |
| 183 | + |
| 184 | +```bash |
| 185 | +python src/maxdiffusion/generate_ltx2.py src/maxdiffusion/configs/ltx2_video.yml logical_axis_rules="[('heads', 'model'), ('embed', 'data')]" |
| 186 | +``` |
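
To make the binding concrete, here is a minimal sketch of how a logical spec resolves against a set of rules, using `flax.linen.logical_to_mesh_axes` (the rules mirror the CLI override above):

```python
from flax import linen as nn

rules = (("heads", "model"), ("embed", "data"))

# ("embed", "heads") is the trillium qkv_kernel spec from the registry.
pspec = nn.logical_to_mesh_axes(("embed", "heads"), rules)
# pspec == PartitionSpec('data', 'model'); names with no matching rule
# resolve to None (unsharded), which motivates the validation below.
```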
| 187 | + |
| 188 | +### 5. Startup Validation |
| 189 | +To ensure that the configuration and code are in sync, we propose adding a validation step at startup (e.g., in the pipeline or `pyconfig.py`). |
| 190 | + |
| 191 | +**Problem**: If a logical sharding spec uses an axis name (e.g., `"heads"`) that is not defined in the active `logical_axis_rules`, JAX might fail late or silently fall back to suboptimal sharding. |
| 192 | + |
| 193 | +**Solution**: |
| 194 | +1. Collect all logical axis names used in the active sharding strategies. |
| 195 | +2. Cross-reference them with the keys in `logical_axis_rules`. |
| 196 | +3. If any logical axis name is missing from `logical_axis_rules`, raise a `ValueError` to fail fast. |
| 197 | +4. Allow users to bypass this check with a `--skip_sharding_validation` flag if they explicitly want to proceed with potential defaults. |
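
A minimal sketch of the fail-fast check, assuming `component_specs` is one strategy's component dict (e.g., `STRATEGIES["ironwood"]`) and `logical_axis_rules` is a sequence of `(logical_name, mesh_axis)` pairs from the config:

```python
import dataclasses


def validate_sharding_axes(component_specs, logical_axis_rules, skip=False):
  """Raises ValueError if a spec uses a logical axis with no rule."""
  if skip:  # corresponds to the proposed --skip_sharding_validation flag
    return
  known = {name for name, _ in logical_axis_rules}
  used = set()
  for specs in component_specs.values():
    for field in dataclasses.fields(specs):
      value = getattr(specs, field.name)
      if isinstance(value, tuple):
        used.update(axis for axis in value if axis is not None)
  missing = used - known
  if missing:
    raise ValueError(
        f"Logical axes {sorted(missing)} are not defined in logical_axis_rules."
    )
```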
| 198 | + |
| 199 | +## Performance Considerations |
| 200 | +- This is purely a code-structuring change and does not introduce any runtime overhead. |
| 201 | +- The specs returned by the factory are static strings or tuples of strings, which are perfectly traced by JAX and compiled by XLA. |
| 202 | + |
| 203 | +## Alternatives Considered |
| 204 | + |
| 205 | +### 1. Hardcoded Hardware Checks |
| 206 | +Checking `device_kind` directly in the model components. |
| 207 | +- **Why rejected**: Scattered checks make the code hard to maintain, extend, and test. |
| 208 | + |
| 209 | +### 2. Excessive Configuration/Plumbing (Pure YAML or Individual Constructor Arguments) |
| 210 | +Putting all specs in YAML or passing every spec individually through all constructors. |
| 211 | +- **Why rejected**: Leads to either bloated configuration files or polluted constructors in intermediate layers. We struck a balance by using dataclasses at the top level and unpacking them for leaf nodes. |
| 212 | + |
| 213 | +### 3. Monolithic or Global State Objects |
| 214 | +Using a class-based strategy per hardware or a global singleton manager. |
| 215 | +- **Why rejected**: Leads to class explosion or violates JAX functional purity principles by introducing global state. |
| 216 | + |
| 217 | +## Prototype Plan: LTX2 |
| 218 | +We will use the LTX2 model as a prototype to validate this design. |
| 219 | + |
| 220 | +1. **Create** the `src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py` file with the specs and factory. |
| 221 | +2. **Update** [ltx2_pipeline.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py) to read the config and get the strategy. |
| 222 | +3. **Update** `transformer_ltx2.py` and `attention_ltx2.py` to accept and use the strategy object. |
| 223 | +4. **Verify** by: |
| 224 | + * Adding unit tests for the factory and strategy objects. |
| 225 | + * Running existing LTX2 integration tests with both `ironwood` and `trillium` strategies to ensure no regressions. |
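
A minimal sketch of such unit tests, assuming the module path given above:

```python
import pytest

from maxdiffusion.models.ltx2.logical_sharding_ltx2 import get_sharding_specs


def test_ironwood_dit_specs():
  specs = get_sharding_specs("ironwood", "ltx2_dit")
  assert specs.qkv_kernel == (None, "heads")
  assert specs.out_bias == (None,)


def test_unknown_strategy_falls_back_to_trillium():
  specs = get_sharding_specs("default", "ltx2_dit")
  assert specs.qkv_kernel == ("embed", "heads")


def test_unknown_component_raises():
  with pytest.raises(ValueError):
    get_sharding_specs("trillium", "no_such_component")
```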

## Shared Components (e.g., `attention_flax.py`)
For components shared across different models (like `NNXSimpleFeedForward` in [attention_flax.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/models/attention_flax.py)), we will pass the specific sharding specs as constructor arguments; the LTX2-specific caller fetches those values from its own specs object.
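
For example, a sketch of the LTX2-side call (the `ff_kernel` field and the `kernel_sharding_spec` parameter are hypothetical names for this proposal, not existing API):

```python
# Hypothetical: LTX2DiTShardingSpecs gains an ff_kernel field, and the
# shared layer gains a kernel_sharding_spec argument.
ff = NNXSimpleFeedForward(
    ...,
    kernel_sharding_spec=sharding_specs.ff_kernel,
)
```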
| 229 | + |
| 230 | +## Future Expansion |
| 231 | +If the prototype succeeds on LTX2, we plan to expand this pattern to other models like **WAN** and **Flux** by adding corresponding strategies and factories. |
| 232 | + |
| 233 | + |