
Commit 7bed4f9

Commit message: formatting
Parent: 13656fb

11 files changed: 2023 additions & 2122 deletions


src/maxdiffusion/__init__.py

Lines changed: 358 additions & 365 deletions
Large diffs are not rendered by default.

src/maxdiffusion/generate_ltx_video.py

Lines changed: 24 additions & 36 deletions
@@ -14,67 +14,55 @@
 limitations under the License.
 """
 
-
 from absl import app
 from typing import Sequence
 import jax
 import json
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
 import os
-import functools
 import jax.numpy as jnp
 from maxdiffusion import pyconfig
 from maxdiffusion.max_utils import (
     create_device_mesh,
-    setup_initial_state,
 )
-from jax.sharding import Mesh, PartitionSpec as P
 
 
 def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
-    print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
-    print("fractional_coords.shape: ",
-          fractional_coords.shape, fractional_coords.dtype)
-    print("latents.shape: ", latents.shape, latents.dtype)
-    print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
+  print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
+  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
+  print("latents.shape: ", latents.shape, latents.dtype)
+  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
 
 
 def run(config):
-    key = jax.random.PRNGKey(0)
+  key = jax.random.PRNGKey(0)
+
+  devices_array = create_device_mesh(config)
+  mesh = Mesh(devices_array, config.mesh_axes)
+
+  batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
+  base_dir = os.path.dirname(__file__)
 
-    devices_array = create_device_mesh(config)
-    mesh = Mesh(devices_array, config.mesh_axes)
+  # load in model config
+  config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
+  with open(config_path, "r") as f:
+    model_config = json.load(f)
 
-    batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
-    base_dir = os.path.dirname(__file__)
+  transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
+  transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False)
 
-    # load in model config
-    config_path = os.path.join(
-        base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
-    with open(config_path, "r") as f:
-        model_config = json.load(f)
+  key, split_key = jax.random.split(key)
 
-    transformer = Transformer3DModel(
-        **model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
-    transformer_param_shapes = transformer.init_weights(
-        key, batch_size, text_tokens, num_tokens, features, eval_only=False)
 
-    key, split_key = jax.random.split(key)
-    weights_init_fn = functools.partial(
-        transformer.init_weights,
-        split_key,
-        batch_size,
-        text_tokens,
-        num_tokens,
-        features,
-        eval_only=True
-    )
+  weights_init_fn = functools.partial(
+      transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True
+  )
 
 
 def main(argv: Sequence[str]) -> None:
-    pyconfig.initialize(argv)
-    run(pyconfig.config)
+  pyconfig.initialize(argv)
+  run(pyconfig.config)
 
 
 if __name__ == "__main__":
-    app.run(main)
+  app.run(main)
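Note: the hunk ends after building weights_init_fn, a zero-argument closure over transformer.init_weights; the commit does not show it being consumed. As a minimal sketch of how such a closure is typically used (not part of this commit; materialize_params and its use of jax.eval_shape/jax.jit are illustrative assumptions):

    import jax

    def materialize_params(weights_init_fn, mesh):
      # Hypothetical helper, not from this diff: first trace the initializer
      # to get parameter shapes/dtypes without allocating device memory,
      # then run it under the mesh to produce real device arrays.
      abstract_params = jax.eval_shape(weights_init_fn)
      with mesh:
        params = jax.jit(weights_init_fn)()
      return abstract_params, params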

src/maxdiffusion/models/__init__.py

Lines changed: 8 additions & 9 deletions
@@ -25,15 +25,14 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 
-    from .controlnet_flax import FlaxControlNetModel
-    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .vae_flax import FlaxAutoencoderKL
-    from .lora import *
-    from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
-    from .ltx_video.transformers.transformer3d import Transformer3DModel
+  from .controlnet_flax import FlaxControlNetModel
+  from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+  from .vae_flax import FlaxAutoencoderKL
+  from .lora import *
+  from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+  from .ltx_video.transformers.transformer3d import Transformer3DModel
 
 else:
-    import sys
+  import sys
 
-    sys.modules[__name__] = _LazyModule(
-        __name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+  sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
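Note: the else branch replaces this package's entry in sys.modules with a _LazyModule, so the heavy Flax imports above only execute when a name is first accessed. A minimal sketch of that pattern (simplified; the real _LazyModule shipped with diffusers/maxdiffusion handles more cases):

    import importlib
    import types

    class LazyModule(types.ModuleType):
      """Illustrative stand-in for _LazyModule: resolve attributes on first use."""

      def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each public name to the submodule that defines it.
        self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

      def __getattr__(self, attr):
        if attr not in self._attr_to_module:
          raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module("." + self._attr_to_module[attr], self.__name__)
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache, so __getattr__ is skipped next time
        return value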

src/maxdiffusion/models/ltx_video/gradient_checkpoint.py

Lines changed: 51 additions & 51 deletions
@@ -8,63 +8,63 @@
 
 
 class GradientCheckpointType(Enum):
-    """
-    Defines the type of the gradient checkpoint we will have
+  """
+  Defines the type of the gradient checkpoint we will have
 
-    NONE - means no gradient checkpoint
-    FULL - means full gradient checkpoint, wherever possible (minimum memory usage)
-    MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
-    except for ones that involve batch dimension - that means that all attention and projection
-    layers will have gradient checkpoint, but not the backward with respect to the parameters
-    """
+  NONE - means no gradient checkpoint
+  FULL - means full gradient checkpoint, wherever possible (minimum memory usage)
+  MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
+  except for ones that involve batch dimension - that means that all attention and projection
+  layers will have gradient checkpoint, but not the backward with respect to the parameters
+  """
 
-    NONE = auto()
-    FULL = auto()
-    MATMUL_WITHOUT_BATCH = auto()
+  NONE = auto()
+  FULL = auto()
+  MATMUL_WITHOUT_BATCH = auto()
 
-    @classmethod
-    def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
-        """
-        Constructs the gradient checkpoint type from a string
+  @classmethod
+  def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
+    """
+    Constructs the gradient checkpoint type from a string
 
-        Args:
-            s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.
+    Args:
+        s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.
 
-        Returns:
-            GradientCheckpointType: The policy that corresponds to the string
-        """
-        if s is None:
-            s = "none"
-        return GradientCheckpointType[s.upper()]
+    Returns:
+        GradientCheckpointType: The policy that corresponds to the string
+    """
+    if s is None:
+      s = "none"
+    return GradientCheckpointType[s.upper()]
 
-    def to_jax_policy(self):
-        """
-        Converts the gradient checkpoint type to a jax policy
-        """
-        match self:
-            case GradientCheckpointType.NONE:
-                return SKIP_GRADIENT_CHECKPOINT_KEY
-            case GradientCheckpointType.FULL:
-                return None
-            case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
-                return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+  def to_jax_policy(self):
+    """
+    Converts the gradient checkpoint type to a jax policy
+    """
+    match self:
+      case GradientCheckpointType.NONE:
+        return SKIP_GRADIENT_CHECKPOINT_KEY
+      case GradientCheckpointType.FULL:
+        return None
+      case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
+        return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
 
-    def apply(self, module: nn.Module) -> nn.Module:
-        """
-        Applies a gradient checkpoint policy to a module
-        if no policy is needed, it will return the module as is
+  def apply(self, module: nn.Module) -> nn.Module:
+    """
+    Applies a gradient checkpoint policy to a module
+    if no policy is needed, it will return the module as is
 
-        Args:
-            module (nn.Module): the module to apply the policy to
+    Args:
+        module (nn.Module): the module to apply the policy to
 
-        Returns:
-            nn.Module: the module with the policy applied
-        """
-        policy = self.to_jax_policy()
-        if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
-            return module
-        return nn.remat(  # pylint: disable=invalid-name
-            module,
-            prevent_cse=False,
-            policy=policy,
-        )
+    Returns:
+        nn.Module: the module with the policy applied
+    """
+    policy = self.to_jax_policy()
+    if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
+      return module
+    return nn.remat(  # pylint: disable=invalid-name
+        module,
+        prevent_cse=False,
+        policy=policy,
+    )
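Note: this hunk only re-indents the class; its behavior is unchanged. For orientation, typical usage looks like the following sketch (the config string and TransformerBlock are illustrative, not taken from this diff):

    # Pick a checkpointing policy from a config string and wrap a Flax module.
    # "matmul_without_batch" rematerializes matmuls without a batch dimension,
    # trading recompute for memory in attention/projection layers.
    policy = GradientCheckpointType.from_str("matmul_without_batch")
    CheckpointedBlock = policy.apply(TransformerBlock)  # TransformerBlock is a stand-in nn.Module
    # from_str(None) falls back to NONE, in which case apply() returns the module unchanged.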
