Commit 6552f14

Author: Serena
Commit message: model setup
1 parent 0af353d · commit 6552f14

17 files changed

Lines changed: 2090 additions & 3 deletions

setup.sh

Lines changed: 1 addition & 1 deletion
@@ -110,4 +110,4 @@ else
 fi

 # Install maxdiffusion
-pip3 install -U . || echo "Failed to install maxdiffusion" >&2
+pip3 install -e . || echo "Failed to install maxdiffusion" >&2
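Note: "pip3 install -e ." installs maxdiffusion in editable (development) mode, so local source changes take effect without reinstalling; the previous "pip3 install -U ." built and installed a copy of the package instead.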

src/maxdiffusion/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -373,6 +373,7 @@
 _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
 _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
 _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
+_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
 _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
 _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
 _import_structure["schedulers"].extend(
@@ -453,6 +454,7 @@
 from .models.modeling_flax_utils import FlaxModelMixin
 from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
 from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+from .models.ltx_video.transformers.transformer3d import Transformer3DModel
 from .models.vae_flax import FlaxAutoencoderKL
 from .pipelines import FlaxDiffusionPipeline
 from .schedulers import (
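With both the _import_structure entry and the TYPE_CHECKING import in place, the new model should be importable from the package root through the lazy-import machinery (a small sketch, assuming the _LazyModule wiring in this __init__.py follows the usual diffusers pattern):

from maxdiffusion import Transformer3DModel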
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#hardware
2+
hardware: 'tpu'
3+
skip_jax_distributed_system: False
4+
5+
jax_cache_dir: ''
6+
weights_dtype: 'bfloat16'
7+
activations_dtype: 'bfloat16'
8+
9+
10+
run_name: ''
11+
output_dir: 'ltx-video-output'
12+
save_config_to_gcs: False
13+
14+
#parallelism
15+
mesh_axes: ['data', 'fsdp', 'tensor']
16+
logical_axis_rules: [
17+
['batch', 'data'],
18+
['activation_batch', ['data','fsdp']],
19+
['activation_heads', 'tensor'],
20+
['activation_kv', 'tensor'],
21+
['mlp','tensor'],
22+
['embed','fsdp'],
23+
['heads', 'tensor'],
24+
['conv_batch', ['data','fsdp']],
25+
['out_channels', 'tensor'],
26+
['conv_out', 'fsdp'],
27+
]
28+
data_sharding: [['data', 'fsdp', 'tensor']]
29+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
30+
dcn_fsdp_parallelism: -1
31+
dcn_tensor_parallelism: 1
32+
ici_data_parallelism: -1
33+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
34+
ici_tensor_parallelism: 1
35+
36+
37+
38+
39+
learning_rate_schedule_steps: -1
40+
max_train_steps: 500 #TODO: change this
41+
pretrained_model_name_or_path: ''
42+
unet_checkpoint: ''
43+
dataset_name: 'diffusers/pokemon-gpt4-captions'
44+
train_split: 'train'
45+
dataset_type: 'tf'
46+
cache_latents_text_encoder_outputs: True
47+
per_device_batch_size: 1
48+
compile_topology_num_slices: -1
49+
quantization_local_shard_count: -1
50+
jit_initializers: True
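For context, a minimal sketch (not part of the commit) of how mesh_axes and data_sharding typically translate into a JAX device mesh. In maxdiffusion the real mesh is built by create_device_mesh(config), which resolves the -1 ("auto") ici/dcn entries from the available devices, so the reshape below is only illustrative:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh_axes = ("data", "fsdp", "tensor")
# Put every device on the 'data' axis; the fsdp and tensor axes have size 1.
devices = np.asarray(jax.devices()).reshape(jax.device_count(), 1, 1)
mesh = Mesh(devices, mesh_axes)

# data_sharding: [['data', 'fsdp', 'tensor']] shards the leading (batch)
# dimension of an array across all three mesh axes.
batch_sharding = NamedSharding(mesh, P(("data", "fsdp", "tensor")))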
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from absl import app
2+
from typing import Sequence
3+
import jax
4+
import json
5+
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
6+
import os
7+
import functools
8+
import jax.numpy as jnp
9+
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
10+
from maxdiffusion.max_utils import (
11+
create_device_mesh,
12+
setup_initial_state,
13+
)
14+
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
15+
16+
17+
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
18+
print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
19+
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
20+
print("latents.shape: ", latents.shape, latents.dtype)
21+
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
22+
23+
def run(config):
24+
key = jax.random.PRNGKey(0)
25+
26+
devices_array = create_device_mesh(config)
27+
mesh = Mesh(devices_array, config.mesh_axes)
28+
29+
batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
30+
base_dir = os.path.dirname(__file__)
31+
32+
##load in model config
33+
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
34+
with open(config_path, "r") as f:
35+
model_config = json.load(f)
36+
37+
38+
transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
39+
transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only = False)
40+
41+
key, split_key = jax.random.split(key)
42+
weights_init_fn = functools.partial(
43+
transformer.init_weights,
44+
split_key,
45+
batch_size,
46+
text_tokens,
47+
num_tokens,
48+
features,
49+
eval_only = False
50+
)
51+
52+
transformer_state, transformer_state_shardings = setup_initial_state(
53+
model=transformer,
54+
tx=None,
55+
config=config,
56+
mesh=mesh,
57+
weights_init_fn=weights_init_fn,
58+
model_params=None,
59+
training=False,
60+
)
61+
62+
63+
64+
def main(argv: Sequence[str]) -> None:
65+
pyconfig.initialize(argv)
66+
run(pyconfig.config)
67+
68+
69+
if __name__ == "__main__":
70+
app.run(main)
71+
72+
73+
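This script presumably follows the maxdiffusion/MaxText convention of passing the config path on the command line (for example, python <this file> <config.yml>), since pyconfig.initialize(argv) reads its settings from argv; the exact invocation is an assumption, not shown in this commit.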

src/maxdiffusion/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING

-from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available


 _import_structure = {}
@@ -32,7 +32,7 @@
     from .vae_flax import FlaxAutoencoderKL
     from .lora import *
     from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
-
+    from .ltx_video.transformers.transformer3d import Transformer3DModel
 else:
     import sys

src/maxdiffusion/models/ltx_video/__init__.py

Whitespace-only changes.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from enum import Enum, auto
2+
from typing import Optional
3+
4+
import jax
5+
from flax import linen as nn
6+
7+
SKIP_GRADIENT_CHECKPOINT_KEY = "skip"
8+
9+
10+
class GradientCheckpointType(Enum):
11+
"""
12+
Defines the type of the gradient checkpoint we will have
13+
14+
NONE - means no gradient checkpoint
15+
FULL - means full gradient checkpoint, wherever possible (minimum memory usage)
16+
MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
17+
except for ones that involve batch dimension - that means that all attention and projection
18+
layers will have gradient checkpoint, but not the backward with respect to the parameters
19+
"""
20+
21+
NONE = auto()
22+
FULL = auto()
23+
MATMUL_WITHOUT_BATCH = auto()
24+
25+
@classmethod
26+
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
27+
"""
28+
Constructs the gradient checkpoint type from a string
29+
30+
Args:
31+
s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.
32+
33+
Returns:
34+
GradientCheckpointType: The policy that corresponds to the string
35+
"""
36+
if s is None:
37+
s = "none"
38+
return GradientCheckpointType[s.upper()]
39+
40+
def to_jax_policy(self):
41+
"""
42+
Converts the gradient checkpoint type to a jax policy
43+
"""
44+
match self:
45+
case GradientCheckpointType.NONE:
46+
return SKIP_GRADIENT_CHECKPOINT_KEY
47+
case GradientCheckpointType.FULL:
48+
return None
49+
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
50+
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
51+
52+
def apply(self, module: nn.Module) -> nn.Module:
53+
"""
54+
Applies a gradient checkpoint policy to a module
55+
if no policy is needed, it will return the module as is
56+
57+
Args:
58+
module (nn.Module): the module to apply the policy to
59+
60+
Returns:
61+
nn.Module: the module with the policy applied
62+
"""
63+
policy = self.to_jax_policy()
64+
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
65+
return module
66+
return nn.remat( # pylint: disable=invalid-name
67+
module,
68+
prevent_cse=False,
69+
policy=policy,
70+
)
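A minimal usage sketch (not part of the commit): from_str maps the config string to a policy and apply wraps a Flax module class in nn.remat. MyBlock is a toy module invented purely for illustration, and GradientCheckpointType is assumed to be importable from this new module (its path is not shown in the diff).

import jax
import jax.numpy as jnp
from flax import linen as nn


class MyBlock(nn.Module):  # hypothetical module, for illustration only
  features: int = 128

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


# "matmul_without_batch" matches the gradient_checkpointing value used elsewhere in this commit.
policy = GradientCheckpointType.from_str("matmul_without_batch")
RematBlock = policy.apply(MyBlock)  # returns the class wrapped with nn.remat

block = RematBlock()
x = jnp.ones((4, 128))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)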
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from typing import Union, Iterable, Tuple, Optional, Callable
2+
3+
import numpy as np
4+
import jax
5+
import jax.numpy as jnp
6+
from flax import linen as nn
7+
from flax.linen.initializers import lecun_normal
8+
9+
10+
Shape = Tuple[int, ...]
11+
Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array]
12+
InitializerAxis = Union[int, Shape]
13+
14+
15+
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
16+
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
17+
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
18+
19+
20+
def _canonicalize_tuple(x):
21+
if isinstance(x, Iterable):
22+
return tuple(x)
23+
else:
24+
return (x,)
25+
26+
27+
NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
28+
KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
29+
30+
31+
class DenseGeneral(nn.Module):
32+
"""A linear transformation with flexible axes.
33+
34+
Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86
35+
36+
Attributes:
37+
features: tuple with numbers of output features.
38+
axis: tuple with axes to apply the transformation on.
39+
weight_dtype: the dtype of the weights (default: float32).
40+
dtype: the dtype of the computation (default: float32).
41+
kernel_init: initializer function for the weight matrix.
42+
use_bias: whether to add bias in linear transformation.
43+
bias_norm: whether to add normalization before adding bias.
44+
quant: quantization config, defaults to None implying no quantization.
45+
"""
46+
47+
features: Union[Iterable[int], int]
48+
axis: Union[Iterable[int], int] = -1
49+
weight_dtype: jnp.dtype = jnp.float32
50+
dtype: np.dtype = jnp.float32
51+
kernel_init: KernelInitializer = lecun_normal()
52+
kernel_axes: Tuple[Optional[str], ...] = ()
53+
use_bias: bool = False
54+
matmul_precision: str = "default"
55+
56+
bias_init: Initializer = jax.nn.initializers.constant(0.0)
57+
58+
@nn.compact
59+
def __call__(self, inputs: jax.Array) -> jax.Array:
60+
"""Applies a linear transformation to the inputs along multiple dimensions.
61+
62+
Args:
63+
inputs: The nd-array to be transformed.
64+
65+
Returns:
66+
The transformed input.
67+
"""
68+
69+
def compute_dot_general(inputs, kernel, axis, contract_ind):
70+
"""Computes a dot_general operation that may be quantized."""
71+
dot_general = jax.lax.dot_general
72+
matmul_precision = jax.lax.Precision(self.matmul_precision)
73+
return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)
74+
75+
features = _canonicalize_tuple(self.features)
76+
axis = _canonicalize_tuple(self.axis)
77+
78+
inputs = jnp.asarray(inputs, self.dtype)
79+
axis = _normalize_axes(axis, inputs.ndim)
80+
81+
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
82+
kernel_in_axis = np.arange(len(axis))
83+
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
84+
kernel = self.param(
85+
"kernel",
86+
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
87+
kernel_shape,
88+
self.weight_dtype,
89+
)
90+
kernel = jnp.asarray(kernel, self.dtype)
91+
92+
contract_ind = tuple(range(0, len(axis)))
93+
output = compute_dot_general(inputs, kernel, axis, contract_ind)
94+
95+
if self.use_bias:
96+
bias_axes, bias_shape = (
97+
self.kernel_axes[-len(features) :],
98+
kernel_shape[-len(features) :],
99+
)
100+
bias = self.param(
101+
"bias",
102+
nn.with_logical_partitioning(self.bias_init, bias_axes),
103+
bias_shape,
104+
self.weight_dtype,
105+
)
106+
bias = jnp.asarray(bias, self.dtype)
107+
108+
output += bias
109+
return output
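A minimal usage sketch (not part of the commit) of DenseGeneral, contracting the trailing feature axis of a (batch, tokens, features) activation; the shapes and logical axis names below are illustrative, chosen to line up with the logical_axis_rules in the config above.

import jax
import jax.numpy as jnp

# Project 128 input features to 256 output features along axis -1.
layer = DenseGeneral(
    features=256,
    axis=-1,
    dtype=jnp.bfloat16,
    weight_dtype=jnp.float32,
    kernel_axes=("embed", "mlp"),  # logical names, resolved by logical_axis_rules
    use_bias=True,
)

x = jnp.ones((4, 2048, 128), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)  # shape (4, 2048, 256), computed in bfloat16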
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import jax
3+
import jax.numpy as jnp
4+
import json
5+
6+
7+
from models.transformers.transformer3d import Transformer3DModel
8+
9+
# Load JSON config
10+
base_dir = os.path.dirname(__file__)
11+
config_path = os.path.join(base_dir, "xora_v1.2-13B-balanced-128.json")
12+
with open(config_path, "r") as f:
13+
model_config = json.load(f)
14+
15+
key = jax.random.PRNGKey(0)
16+
model = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
17+
18+
batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
19+
prompt_embeds = jax.random.normal(key, shape=(batch_size, text_tokens, features), dtype=jnp.bfloat16)
20+
fractional_coords = jax.random.normal(key, shape=(batch_size, 3, num_tokens), dtype=jnp.bfloat16)
21+
latents = jax.random.normal(key, shape=(batch_size, num_tokens, features), dtype=jnp.bfloat16)
22+
noise_cond = jax.random.normal(key, shape=(batch_size, 1), dtype=jnp.bfloat16)
23+
24+
model_params = model.init(
25+
hidden_states=latents,
26+
indices_grid=fractional_coords,
27+
encoder_hidden_states=prompt_embeds,
28+
timestep=noise_cond,
29+
rngs={"params": key}
30+
)
31+
32+
output = model.apply(
33+
model_params,
34+
hidden_states=latents,
35+
indices_grid=fractional_coords,
36+
encoder_hidden_states=prompt_embeds,
37+
timestep=noise_cond,
38+
)
39+
40+
print("done!")
