Commit 98d4fcb (parent: c98002f)

feat: add LTX2 smoke test and fix pipeline state sharding

4 files changed: 358 additions & 16 deletions

docs/sharding_strategy_design.md (new file, 233 additions)
# Design Doc: Centralized and Configuration-Driven Sharding Strategy

## Objective

To centralize the sharding logic in MaxDiffusion, enabling hardware-specific optimizations (e.g., for TPU v6e vs. v7x) without hardcoding checks in model layers or polluting constructors with sharding parameters.

## Background

Currently, sharding specifications are often hardcoded within model layers or determined by ad-hoc hardware checks (e.g., checking `jax.devices()[0].device_kind`). This makes the code:

- **Hard to maintain and extend**: Adding support for new hardware requires modifying multiple files.
- **Difficult to test**: It is hard to exercise different sharding strategies on the same hardware for debugging or benchmarking.
- **Cluttered**: Model definition code is mixed with hardware-specific execution policies.

In the `prisha/ltx2_opt` branch, we see initial attempts to address this by abstracting TPU type detection, but the sharding specs themselves are still hardcoded based on the detected hardware in [attention_ltx2.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/models/ltx2/attention_ltx2.py).

## Proposed Design

We propose a design that combines **discrete logical rulesets** and **explicit parameter passing at the top level** to achieve a clean separation of concerns while adhering to JAX and Flax NNX best practices.

### 1. Configuration

We will add a `sharding` section to the YAML configuration files, allowing independent overrides for different model components (e.g., Transformer, VAE).

Example in `ltx2_video.yml`:

```yaml
sharding:
  transformer: 'ironwood'
  vae: 'default'
  text_encoder: 'default'
```

#### Auto-Detection & Backward Compatibility

To improve usability and ensure backward compatibility:

- **Auto-detection**: Specifying the sharding strategy is **optional**. If it is omitted (or a legacy config file lacks the `sharding` block entirely), `pyconfig.py` will auto-detect the TPU hardware generation at startup and set the strategy to the optimal default for that chip (e.g., `'ironwood'` for v7x).
- **Logging**: The resolved strategy will be logged explicitly to maintain transparency.
- **Overrides**: Users can always override the auto-detection by setting the strategy explicitly in the YAML file or via the CLI.
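The auto-detection rule above reduces to a pure lookup that is easy to unit-test. The following is a hedged sketch only: the function name `resolve_sharding_strategy` and the exact `device_kind` strings are assumptions for illustration, not the actual MaxDiffusion API.

```python
from typing import Optional

# Hypothetical mapping from a JAX device_kind string to the default
# strategy for that chip; the exact strings are assumptions.
_DEFAULT_STRATEGY_BY_DEVICE_KIND = {
    "TPU v6e": "trillium",
    "TPU v7x": "ironwood",
}


def resolve_sharding_strategy(configured: Optional[str], device_kind: str) -> str:
  """Returns the configured strategy, or an auto-detected default."""
  if configured:  # An explicit YAML/CLI setting always wins.
    return configured
  # Unrecognized hardware falls back to the 'trillium' defaults.
  return _DEFAULT_STRATEGY_BY_DEVICE_KIND.get(device_kind, "trillium")


# In pyconfig.py this would be driven by jax.devices()[0].device_kind.
print(resolve_sharding_strategy(None, "TPU v7x"))        # ironwood
print(resolve_sharding_strategy("trillium", "TPU v7x"))  # trillium
```

The resolved value would then be logged and written back into the config so downstream code sees a concrete strategy name.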
### 2. Discrete Logical Rulesets (Model-Specific File)

To keep the code simple and avoid file clutter, we organize the sharding specs into a single file per model, located in the model's directory. This keeps the sharding logic close to the model code, where model developers can read it easily.

For LTX2, this file will be `src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py`.

This file will contain the discrete specs, the registry, and the factory function:
```python
from dataclasses import dataclass
from typing import Any, Optional


# --- Discrete Specs ---
@dataclass
class LTX2DiTShardingSpecs:
  """Sharding specs for the LTX2 Diffusion Transformer."""

  qkv_kernel: tuple
  out_kernel: tuple
  out_bias: tuple
  norm_scale: tuple = ("norm",)
  embed_bias: tuple = ("embed",)


@dataclass
class TextEncoderShardingSpecs:
  """Specs for the Text Encoder execution."""

  use_batched_text_encoder: bool = False
  text_encoder_kernel: Optional[tuple] = None


@dataclass
class VAEShardingSpecs:
  """Sharding specs for the VAE."""

  vae_conv_kernel: Optional[tuple] = None


# --- Unified Registry for LTX2 ---
STRATEGIES = {
    "ironwood": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=(None, "heads"),
            out_kernel=("heads", None),
            out_bias=(None,),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=True,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
    },
    "trillium": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=("embed", "heads"),
            out_kernel=("heads", "embed"),
            out_bias=("embed",),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=False,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=(None, None, None, None)),
    },
}


def get_sharding_specs(strategy_name: str, component_name: str) -> Any:
  """Unified factory to get specs for any component."""
  # Unknown strategy names fall back to the 'trillium' profile.
  hardware_profile = STRATEGIES.get(strategy_name, STRATEGIES["trillium"])
  specs = hardware_profile.get(component_name)
  if specs is None:
    raise ValueError(f"Component {component_name} not found in strategy {strategy_name}")
  return specs
```
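As written, the factory silently falls back to the `'trillium'` profile for any unregistered strategy name, including `'default'`. The standalone sketch below inlines a trimmed copy of the registry (only `qkv_kernel` is kept, for brevity) to demonstrate both the lookup and the fallback:

```python
from dataclasses import dataclass


# Trimmed inline copy of the registry/factory pattern, so this snippet
# runs standalone; the real module carries the full spec dataclasses.
@dataclass
class LTX2DiTShardingSpecs:
  qkv_kernel: tuple


STRATEGIES = {
    "ironwood": {"ltx2_dit": LTX2DiTShardingSpecs(qkv_kernel=(None, "heads"))},
    "trillium": {"ltx2_dit": LTX2DiTShardingSpecs(qkv_kernel=("embed", "heads"))},
}


def get_sharding_specs(strategy_name: str, component_name: str):
  # Unknown strategy names resolve to the 'trillium' profile.
  hardware_profile = STRATEGIES.get(strategy_name, STRATEGIES["trillium"])
  specs = hardware_profile.get(component_name)
  if specs is None:
    raise ValueError(f"Component {component_name} not found in strategy {strategy_name}")
  return specs


print(get_sharding_specs("ironwood", "ltx2_dit").qkv_kernel)  # (None, 'heads')
print(get_sharding_specs("default", "ltx2_dit").qkv_kernel)   # fallback: ('embed', 'heads')
```

Whether `'default'` should mean "trillium specs" or "auto-detect again" is a design decision worth pinning down during the prototype; an unknown *component* name, by contrast, raises immediately.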
### 3. Application (Unpacking at the Top Level)

To avoid coupling low-level layers to model-specific strategy objects, the top-level model (e.g., `LTX2VideoTransformer3DModel`) accepts the specs object but **unpacks** it, passing only the specific tuples or `PartitionSpec`s down to the leaf nodes (such as `LTX2Attention`).

#### In the Pipeline

The pipeline file (e.g., `ltx2_pipeline.py`) reads the strategy name from the config, retrieves the specific specs object for each component, and passes it to the respective top-level model.
```python
# 1. Read component-specific strategy names from the config.
sharding_config = getattr(self.config, "sharding", {})
transformer_strategy = sharding_config.get("transformer", "default")
te_strategy = sharding_config.get("text_encoder", "default")

# 2. Get the specific specs for each component.
dit_specs = get_sharding_specs(transformer_strategy, "ltx2_dit")
te_specs = get_sharding_specs(te_strategy, "text_encoder")

# 3. Use the specs for pipeline execution choices.
if te_specs.use_batched_text_encoder:
  ...  # run the batched text-encoder path

# 4. Pass the specs to the top-level model.
self.transformer = LTX2VideoTransformer3DModel(
    # ...
    sharding_specs=dit_specs,
)
```
#### In Model Layers

The top-level model receives the specs object and unpacks it for its children.

Example in `LTX2VideoTransformer3DModel`:
```python
class LTX2VideoTransformer3DModel(nnx.Module):

  def __init__(self, ..., sharding_specs: LTX2DiTShardingSpecs):
    # Unpack and pass the specific tuples to the blocks.
    self.block = LTX2VideoTransformerBlock(
        ...,
        qkv_sharding_spec=sharding_specs.qkv_kernel,
        out_sharding_spec=sharding_specs.out_kernel,
        out_bias_sharding_spec=sharding_specs.out_bias,
    )
```
Example in `LTX2Attention` (leaf node):
```python
class LTX2Attention(nnx.Module):

  def __init__(
      self,
      ...,
      qkv_sharding_spec: tuple,
      out_sharding_spec: tuple,
      out_bias_sharding_spec: tuple,
  ):
    # Use the specific tuples directly, completely agnostic to the parent strategy.
    self.qkv_sharding_spec = qkv_sharding_spec
    # ...
```
### 4. Logical-to-Physical Mesh Mapping

Logical axis names like `"heads"` and `"embed"` must be bound to a physical JAX mesh.

In MaxDiffusion, this mapping is handled at the top level via `logical_axis_rules` (typically defined in the YAML config file). These rules map logical axis names to physical mesh axes (e.g., `"data"`, `"model"`, `"fsdp"`).

Different TPU topologies (v6e vs. v7x) have different optimal physical mesh dimensions. We handle this by selecting the appropriate config file via the CLI, or by overriding `logical_axis_rules` directly from the CLI.

Example of overriding `logical_axis_rules` directly via the CLI:

```bash
python src/maxdiffusion/generate_ltx2.py src/maxdiffusion/configs/ltx2_video.yml logical_axis_rules="[('heads', 'model'), ('embed', 'data')]"
```
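For reference, the same rules can live directly in the YAML config. The fragment below is purely illustrative (the specific rule pairs are assumptions, not the shipped `ltx2_video.yml` defaults):

```yaml
# Illustrative only: binds the logical axis names used by the sharding
# specs to physical mesh axes. Actual MaxDiffusion configs may differ.
logical_axis_rules: [
    ['batch', 'data'],
    ['heads', 'model'],
    ['embed', 'fsdp'],
]
```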
### 5. Startup Validation

To ensure that the configuration and code stay in sync, we propose adding a validation step at startup (e.g., in the pipeline or `pyconfig.py`).

**Problem**: If a logical sharding spec uses an axis name (e.g., `"heads"`) that is not defined in the active `logical_axis_rules`, JAX might fail late or silently fall back to suboptimal sharding.

**Solution**:

1. Collect all logical axis names used in the active sharding strategies.
2. Cross-reference them with the keys in `logical_axis_rules`.
3. If any logical axis name is missing from `logical_axis_rules`, raise a `ValueError` to fail fast.
4. Allow users to bypass this check with a `--skip_sharding_validation` flag if they explicitly want to proceed with potential defaults.
## Performance Considerations

- This is purely a code-structuring change and does not introduce any runtime overhead.
- The specs returned by the factory are static strings or tuples of strings, which JAX treats as compile-time constants during tracing and XLA compilation.

## Alternatives Considered

### 1. Hardcoded Hardware Checks

Checking `device_kind` directly in the model components.

- **Why rejected**: Scattered checks make the code hard to maintain, extend, and test.

### 2. Excessive Configuration/Plumbing (Pure YAML or Individual Constructor Arguments)

Putting all specs in YAML, or passing every spec individually through all constructors.

- **Why rejected**: Leads to either bloated configuration files or polluted constructors in intermediate layers. We strike a balance by using dataclasses at the top level and unpacking them for leaf nodes.

### 3. Monolithic or Global State Objects

Using a class-based strategy per hardware type, or a global singleton manager.

- **Why rejected**: Leads to class explosion, or violates JAX functional-purity principles by introducing global state.

## Prototype Plan: LTX2

We will use the LTX2 model as a prototype to validate this design.

1. **Create** the `src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py` file with the specs and factory.
2. **Update** [ltx2_pipeline.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py) to read the config and retrieve the strategy.
3. **Update** `transformer_ltx2.py` and `attention_ltx2.py` to accept and use the strategy object.
4. **Verify** by:
   * Adding unit tests for the factory and strategy objects.
   * Running the existing LTX2 integration tests with both the `ironwood` and `trillium` strategies to ensure no regressions.

## Shared Components (e.g., `attention_flax.py`)

For components shared across different models (like `NNXSimpleFeedForward` in [attention_flax.py](https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/models/attention_flax.py)), we will pass the specific sharding specs as arguments to their constructors; the LTX2-specific caller will fetch those values from the respective specs object.

## Future Expansion

If the prototype succeeds on LTX2, we plan to expand this pattern to other models such as **WAN** and **Flux** by adding corresponding strategies and factories.
setup.sh (9 additions & 2 deletions)

```diff
@@ -22,6 +22,8 @@
 # Enable "exit immediately if any command fails" option
 set -e
 export DEBIAN_FRONTEND=noninteractive
+export PIP_INDEX_URL=https://pypi.org/simple
+export UV_INDEX_URL=https://pypi.org/simple
 
 echo "Checking Python version..."
 # This command will fail if the Python version is less than 3.12
@@ -106,8 +108,13 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
   exit 1
 fi
 
-# Set uv to use system python by default
-export UV_SYSTEM_PYTHON=1
+# Set uv to use system python if not in a virtual environment
+if python3 -c 'import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)'; then
+  echo "Virtual environment detected. UV will use it."
+else
+  echo "System Python detected. Setting UV_SYSTEM_PYTHON=1."
+  export UV_SYSTEM_PYTHON=1
+fi
 
 # Install dependencies from requirements.txt first
 python3 -m uv pip install -U --resolution=lowest \
```

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py (14 additions & 14 deletions)

```diff
@@ -170,8 +170,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
 for path, val in flax.traverse_util.flatten_dict(params).items():
   if restored_checkpoint:
     path = path[:-1]
-  sharding = logical_state_sharding[path].value
-  state[path].value = device_put_replicated(val, sharding)
+  sharding = logical_state_sharding[path].get_value()
+  state[path].set_value(device_put_replicated(val, sharding))
 state = nnx.from_flat_state(state)
 
 transformer = nnx.merge(graphdef, state, rest_of_state)
@@ -351,10 +351,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
 for path, val in flax.traverse_util.flatten_dict(params).items():
   sharding = logical_state_sharding.get(path)
   if sharding is not None:
-    sharding = sharding.value
-    state[path].value = device_put_replicated(val, sharding)
+    sharding = sharding.get_value()
+    state[path].set_value(device_put_replicated(val, sharding))
   else:
-    state[path].value = jax.device_put(val)
+    state[path].set_value(jax.device_put(val))
 
 state = nnx.from_flat_state(state)
 connectors = nnx.merge(graphdef, state, rest_of_state)
@@ -393,16 +393,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
 for path, val in flax.traverse_util.flatten_dict(params).items():
   sharding = logical_state_sharding.get(path)
   if sharding is not None:
-    sharding = sharding.value
+    sharding = sharding.get_value()
     try:
       replicate_vae = config.replicate_vae
     except ValueError:
       replicate_vae = False
     if replicate_vae:
       sharding = NamedSharding(mesh, P())
-    state[path].value = device_put_replicated(val, sharding)
+    state[path].set_value(device_put_replicated(val, sharding))
   else:
-    state[path].value = jax.device_put(val)
+    state[path].set_value(jax.device_put(val))
 
 state = nnx.from_flat_state(state)
 vae = nnx.merge(graphdef, state, rest_of_state)
@@ -441,16 +441,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
 for path, val in flax.traverse_util.flatten_dict(params).items():
   sharding = logical_state_sharding.get(path)
   if sharding is not None:
-    sharding = sharding.value
+    sharding = sharding.get_value()
     try:
       replicate_vae = config.replicate_vae
     except ValueError:
       replicate_vae = False
     if replicate_vae:
       sharding = NamedSharding(mesh, P())
-    state[path].value = device_put_replicated(val, sharding)
+    state[path].set_value(device_put_replicated(val, sharding))
   else:
-    state[path].value = jax.device_put(val)
+    state[path].set_value(jax.device_put(val))
 
 state = nnx.from_flat_state(state)
 audio_vae = nnx.merge(graphdef, state, rest_of_state)
@@ -510,10 +510,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
 for path, val in flax.traverse_util.flatten_dict(params).items():
   sharding = logical_state_sharding.get(path)
   if sharding is not None:
-    sharding = sharding.value
-    state[path].value = device_put_replicated(val, sharding)
+    sharding = sharding.get_value()
+    state[path].set_value(device_put_replicated(val, sharding))
   else:
-    state[path].value = jax.device_put(val)
+    state[path].set_value(jax.device_put(val))
 
 state = nnx.from_flat_state(state)
 vocoder = nnx.merge(graphdef, state, rest_of_state)
```
