
Commit 4e362c3

merge and lint
1 parent dc10581 commit 4e362c3

8 files changed

Lines changed: 143 additions & 126 deletions


src/maxdiffusion/configuration_utils.py

Lines changed: 23 additions & 20 deletions
```diff
@@ -47,21 +47,24 @@

 _re_configuration_file = re.compile(r"config\.(.*)\.json")

+
 class CustomEncoder(json.JSONEncoder):
-  """
-  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
-  """
-  def default(self, o):
-    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
-    if isinstance(o, type(jnp.dtype('bfloat16'))):
-      return str(o)
-    # Add fallbacks for other numpy types if needed
-    if isinstance(o, np.integer):
-      return int(o)
-    if isinstance(o, np.floating):
-      return float(o)
-    # Let the base class default method raise the TypeError for other types
-    return super().default(o)
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
+  """
+
+  def default(self, o):
+    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
+    if isinstance(o, type(jnp.dtype("bfloat16"))):
+      return str(o)
+    # Add fallbacks for other numpy types if needed
+    if isinstance(o, np.integer):
+      return int(o)
+    if isinstance(o, np.floating):
+      return float(o)
+    # Let the base class default method raise the TypeError for other types
+    return super().default(o)
+


 class FrozenDict(OrderedDict):
@@ -596,14 +599,14 @@ def to_json_saveable(value):
     config_dict.pop("quant", None)
     keys_to_remove = []
     for key, value in config_dict.items():
-      # Check the type of the value by its class name to avoid import issues
-      if type(value).__name__ == 'Rngs':
-        keys_to_remove.append(key)
+      # Check the type of the value by its class name to avoid import issues
+      if type(value).__name__ == "Rngs":
+        keys_to_remove.append(key)

     if keys_to_remove:
-      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
-      for key in keys_to_remove:
-        config_dict.pop(key)
+      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
+      for key in keys_to_remove:
+        config_dict.pop(key)

     try:
```
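For context, a minimal sketch of what this encoder fixes: plain `json.dumps` raises `TypeError` on JAX/NumPy dtype objects and NumPy scalars, and passing `cls=CustomEncoder` converts them first. The config keys and values below are invented for illustration; the encoder body is copied from the commit.

```python
import json
from collections import OrderedDict

import jax.numpy as jnp
import numpy as np


class CustomEncoder(json.JSONEncoder):
  """Custom JSON encoder to handle non-serializable types like JAX/NumPy dtypes."""

  def default(self, o):
    # Convert dtype objects such as dtype[bfloat16] to the string "bfloat16".
    if isinstance(o, type(jnp.dtype("bfloat16"))):
      return str(o)
    # Fallbacks for NumPy scalar types.
    if isinstance(o, np.integer):
      return int(o)
    if isinstance(o, np.floating):
      return float(o)
    # Let the base class raise TypeError for anything else.
    return super().default(o)


# Hypothetical config dict; the keys are illustrative only.
config = OrderedDict(weights_dtype=jnp.dtype("bfloat16"), train_steps=np.int64(30))
print(json.dumps(config, cls=CustomEncoder))
# expected: {"weights_dtype": "bfloat16", "train_steps": 30}
```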

src/maxdiffusion/generate_wan.py

Lines changed: 29 additions & 24 deletions
```diff
@@ -22,43 +22,47 @@
 from maxdiffusion.utils import export_to_video
 from google.cloud import storage

+
 def upload_video_to_gcs(output_dir: str, video_path: str):
-  """
-  Uploads a local video file to a specified Google Cloud Storage bucket.
-  """
-  try:
-    path_without_scheme = output_dir.removeprefix("gs://")
-    parts = path_without_scheme.split('/', 1)
-    bucket_name = parts[0]
-    folder_name = parts[1] if len(parts) > 1 else ''
+  """
+  Uploads a local video file to a specified Google Cloud Storage bucket.
+  """
+  try:
+    path_without_scheme = output_dir.removeprefix("gs://")
+    parts = path_without_scheme.split("/", 1)
+    bucket_name = parts[0]
+    folder_name = parts[1] if len(parts) > 1 else ""

-    storage_client = storage.Client()
-    bucket = storage_client.bucket(bucket_name)
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)

-    source_file_path = f"./{video_path}"
-    destination_blob_name = os.path.join(folder_name, "videos", video_path)
+    source_file_path = f"./{video_path}"
+    destination_blob_name = os.path.join(folder_name, "videos", video_path)

-    blob = bucket.blob(destination_blob_name)
+    blob = bucket.blob(destination_blob_name)

-    max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
-    blob.upload_from_filename(source_file_path)
-    max_logging.log(f"Upload complete {source_file_path}.")
+    max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
+    blob.upload_from_filename(source_file_path)
+    max_logging.log(f"Upload complete {source_file_path}.")
+
+  except Exception as e:
+    max_logging.log(f"An error occurred: {e}")

-  except Exception as e:
-    max_logging.log(f"An error occurred: {e}")

 def delete_file(file_path: str):
   if os.path.exists(file_path):
-    try:
-      os.remove(file_path)
-      max_logging.log(f"Successfully deleted file: {file_path}")
-    except OSError as e:
-      max_logging.log(f"Error deleting file '{file_path}': {e}")
+    try:
+      os.remove(file_path)
+      max_logging.log(f"Successfully deleted file: {file_path}")
+    except OSError as e:
+      max_logging.log(f"Error deleting file '{file_path}': {e}")
   else:
-    max_logging.log(f"The file '{file_path}' does not exist.")
+    max_logging.log(f"The file '{file_path}' does not exist.")
+

 jax.config.update("jax_use_shardy_partitioner", True)

+
 def inference_generate_video(config, pipeline, filename_prefix=""):
   s0 = time.perf_counter()
   prompt = [config.prompt] * config.global_batch_size_to_train_on
@@ -88,6 +92,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
   delete_file(f"./{video_path}")
   return

+
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
```
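As a usage note, the bucket/folder parsing in `upload_video_to_gcs` can be traced in isolation. The `gs://` path and file name below are hypothetical, not from the commit:

```python
import os

output_dir = "gs://my-bucket/wan/outputs"  # hypothetical value of config.output_dir
video_path = "sample_0.mp4"                # hypothetical file written by export_to_video

path_without_scheme = output_dir.removeprefix("gs://")
parts = path_without_scheme.split("/", 1)
bucket_name = parts[0]                            # "my-bucket"
folder_name = parts[1] if len(parts) > 1 else ""  # "wan/outputs"

destination_blob_name = os.path.join(folder_name, "videos", video_path)
print(bucket_name, destination_blob_name)  # my-bucket wan/outputs/videos/sample_0.mp4
```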

src/maxdiffusion/models/attention_flax.py

Lines changed: 12 additions & 20 deletions
```diff
@@ -747,9 +747,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                "embed",
-            ),
+            ("embed",),
         ),
     )

@@ -763,9 +761,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                "embed",
-            ),
+            ("embed",),
         ),
     )

@@ -779,9 +775,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                "embed",
-            ),
+            ("embed",),
         ),
     )

@@ -795,9 +789,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                "heads",
-            ),
+            ("heads",),
         ),
     )

@@ -813,9 +805,7 @@ def __init__(
         dtype=dtype,
         scale_init=nnx.with_partitioning(
             nnx.initializers.ones,
-            (
-                "norm",
-            ),
+            ("norm",),
         ),
         param_dtype=weights_dtype,
     )
@@ -826,9 +816,7 @@ def __init__(
         dtype=dtype,
         scale_init=nnx.with_partitioning(
             nnx.initializers.ones,
-            (
-                "norm",
-            ),
+            ("norm",),
         ),
         param_dtype=weights_dtype,
     )
@@ -850,8 +838,12 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
     return xq_out, xk_out

   def __call__(
-      self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None,
-      deterministic: bool = True, rngs: nnx.Rngs = None,
+      self,
+      hidden_states: jax.Array,
+      encoder_hidden_states: jax.Array = None,
+      rotary_emb: Optional[jax.Array] = None,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
```
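The `("embed",)` tuples being collapsed above are sharding annotations. A minimal sketch of the pattern, assuming a toy layer that is not part of this commit: `nnx.with_partitioning` wraps an initializer so the created parameter carries logical-axis metadata, which logical axis rules later map to mesh axes.

```python
from flax import nnx


class TinyProj(nnx.Module):
  """Toy module illustrating partitioned initializers; not from the commit."""

  def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
    self.proj = nnx.Linear(
        din,
        dout,
        # Zero-initialize the bias and tag it with the logical axis "embed".
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
        rngs=rngs,
    )


layer = TinyProj(8, 8, rngs=nnx.Rngs(0))
# The annotation rides along on the param's metadata; expected: ('embed',)
print(layer.proj.bias.sharding)
```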

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 14 additions & 15 deletions
```diff
@@ -67,21 +67,25 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
       case GradientCheckpointType.FULL:
         return None
       case GradientCheckpointType.OFFLOAD_MATMUL_WITHOUT_BATCH:
-        return cp.offload_dot_with_no_batch_dims(
-            offload_src="device", offload_dst="pinned_host"
-        )
+        return cp.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
       case GradientCheckpointType.CUSTOM:
         policy = cp.save_and_offload_only_these_names(
-            names_which_can_be_saved=names_which_can_be_saved,
-            names_which_can_be_offloaded=names_which_can_be_offloaded,
-            offload_src="device",
-            offload_dst="pinned_host"
-        )
+            names_which_can_be_saved=names_which_can_be_saved,
+            names_which_can_be_offloaded=names_which_can_be_offloaded,
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
         return policy
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

-  def apply(self, module: nnx.Module, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], static_argnums=()) -> nnx.Module:
+  def apply(
+      self,
+      module: nnx.Module,
+      names_which_can_be_saved: list = [],
+      names_which_can_be_offloaded: list = [],
+      static_argnums=(),
+  ) -> nnx.Module:
     """
     Applies a gradient checkpoint policy to a module
     if no policy is needed, it will return the module as is
@@ -95,9 +99,4 @@ def apply(self, module: nnx.Module, names_which_can_be_saved: list = [], names_w
     policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded)
     if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
       return module
-    return nnx.remat(  # pylint: disable=invalid-name
-        module,
-        prevent_cse=False,
-        policy=policy,
-        static_argnums=static_argnums
-    )
+    return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
```
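A hedged sketch of what the CUSTOM branch builds on, using JAX's public checkpoint-policy API directly (the function, shapes, and the "ffn_activation" tag below are illustrative): tensors tagged with `checkpoint_name` are saved or offloaded rather than recomputed on the backward pass.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Keep the tensor tagged "ffn_activation"; rematerialize everything else.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["ffn_activation"],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)


def ffn(x, w1, w2):
  # Tag the activation so the policy can match it by name.
  h = checkpoint_name(jax.nn.gelu(x @ w1), "ffn_activation")
  return h @ w2


ffn_remat = jax.checkpoint(ffn, policy=policy, prevent_cse=False)

x, w1, w2 = jnp.ones((4, 16)), jnp.ones((16, 64)), jnp.ones((64, 16))
grads = jax.grad(lambda x, w1, w2: ffn_remat(x, w1, w2).sum())(x, w1, w2)
print(grads.shape)  # (4, 16)
```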

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 29 additions & 12 deletions
```diff
@@ -236,10 +236,10 @@ def __init__(
     )

   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
+    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
     hidden_states = checkpoint_name(hidden_states, "ffn_activation")
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states) # output is (4, 75600, 5120)
+    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
@@ -281,7 +281,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         attention_kernel=attention,
-        dropout=dropout
+        dropout=dropout,
     )

     # 1. Cross-attention
@@ -299,7 +299,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         attention_kernel=attention,
-        dropout=dropout
+        dropout=dropout,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -313,15 +313,24 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
-        dropout=dropout
+        dropout=dropout,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

     key = rngs.params()
     self.adaln_scale_shift_table = nnx.Param(
-        jax.random.normal(key, (1, 6, dim)) / dim**0.5,)
+        jax.random.normal(key, (1, 6, dim)) / dim**0.5,
+    )

-  def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,):
+  def __call__(
+      self,
+      hidden_states: jax.Array,
+      encoder_hidden_states: jax.Array,
+      temb: jax.Array,
+      rotary_emb: jax.Array,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
+  ):
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.adaln_scale_shift_table + temb), 6, axis=1
     )
@@ -331,13 +340,19 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
     attn_output = self.attn1(
-        hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=norm_hidden_states,
+        rotary_emb=rotary_emb,
+        deterministic=deterministic,
+        rngs=rngs,
     )
     hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)

     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
-    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs)
+    attn_output = self.attn2(
+        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+    )
     hidden_states = hidden_states + attn_output

     # 3. Feed-forward
@@ -380,7 +395,7 @@ def __init__(
       attention: str = "dot_product",
       remat_policy: str = "None",
       names_which_can_be_saved: list = [],
-      names_which_can_be_offloaded: list = []
+      names_which_can_be_offloaded: list = [],
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -417,7 +432,7 @@ def __init__(

     # 3. Transformer blocks
     @nnx.split_rngs(splits=num_layers)
-    @nnx.vmap(in_axes=0, out_axes=0, transform_metadata= {nnx.PARTITION_NAME: "layers_per_stage"} )
+    @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
     def init_block(rngs):
       return WanTransformerBlock(
           rngs=rngs,
@@ -496,7 +511,9 @@ def scan_fn(carry, block):
       new_carry = (hidden_states, rngs_carry)
       return new_carry, None

-    rematted_block_forward = self.gradient_checkpoint.apply(scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded)
+    rematted_block_forward = self.gradient_checkpoint.apply(
+        scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
+    )
     initial_carry = (hidden_states, rngs)
     final_carry, _ = nnx.scan(
         rematted_block_forward,
```
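For readers following the block logic, here is a standalone sketch of the adaLN modulation reformatted above (`dim` and shapes are toy values, and the layer norm and attention are stubbed out): the learned `(1, 6, dim)` table plus the time embedding is split into six shift/scale/gate vectors that modulate the normalized hidden states.

```python
import jax
import jax.numpy as jnp

dim = 8
key = jax.random.PRNGKey(0)
adaln_scale_shift_table = jax.random.normal(key, (1, 6, dim)) / dim**0.5
temb = jnp.zeros((1, 6, dim))  # toy stand-in for the broadcast time embedding

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
    adaln_scale_shift_table + temb, 6, axis=1
)

hidden_states = jnp.ones((1, 10, dim))
norm = lambda x: x  # stand-in for FP32LayerNorm
norm_hidden_states = (norm(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
attn_output = norm_hidden_states  # the real block runs self-attention here
hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
print(hidden_states.shape)  # (1, 10, 8)
```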

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def _load_kwargs(self, argv: list[str]):
   @staticmethod
   def wan_init(raw_keys):
     if not any("layers_per_stage" in inner_tuple for inner_tuple in raw_keys["logical_axis_rules"]):
-      raw_keys["logical_axis_rules"]+= (("layers_per_stage", None),)
+      raw_keys["logical_axis_rules"] += (("layers_per_stage", None),)
     if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
       transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
       if transformer_pretrained_model_name_or_path == "":
```
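The fixed line appends a default logical-axis rule when none exists. In isolation, with an invented rules tuple standing in for the parsed config:

```python
# raw_keys is a toy stand-in for the parsed config; only the
# "logical_axis_rules" handling mirrors wan_init.
raw_keys = {
    "logical_axis_rules": (
        ("embed", "fsdp"),
        ("heads", "tensor"),
    )
}

if not any("layers_per_stage" in inner_tuple for inner_tuple in raw_keys["logical_axis_rules"]):
  raw_keys["logical_axis_rules"] += (("layers_per_stage", None),)

print(raw_keys["logical_axis_rules"][-1])  # ('layers_per_stage', None) — unsharded by default
```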
