Skip to content

Commit c0ba5c1

Browse files
linting
1 parent 4245b24 commit c0ba5c1

13 files changed

Lines changed: 283 additions & 989 deletions

end_to_end/tpu/eval_assert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"""
2323

2424

25-
2625
# pylint: skip-file
2726
"""Reads and asserts over target values"""
2827
from absl import app
@@ -47,7 +46,7 @@ def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
4746
target_loss = float(target_loss)
4847
num_samples = int(num_samples_str)
4948
with open(metrics_file, "r", encoding="utf8") as _:
50-
last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples)
49+
last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples)
5150
avg_last_n_data = sum(last_n_data) / len(last_n_data)
5251
print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
5352
print(f"Target loss is {target_loss}")

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 226 deletions
This file was deleted.

src/maxdiffusion/models/flux/util.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from jax import numpy as jnp
1212
from safetensors import safe_open
1313

14-
from ..modeling_flax_pytorch_utils import (
15-
rename_key,
16-
rename_key_and_reshape_tensor,
17-
torch2jax
18-
)
14+
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax)
1915
from maxdiffusion import max_logging
2016

2117

@@ -36,6 +32,7 @@ class FluxParams:
3632
rngs: Array
3733
param_dtype: DTypeLike
3834

35+
3936
@dataclass
4037
class ModelSpec:
4138
params: FluxParams

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
logger = logging.get_logger(__name__)
3131

32+
3233
def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict):
3334
"""
3435
expected_pytree: dict - a pytree that comes from initializing the model.
@@ -54,6 +55,7 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict):
5455
else:
5556
max_logging.log(f"key: {key} not found...")
5657

58+
5759
def torch2jax(torch_tensor: torch.Tensor) -> Array:
5860
is_bfloat16 = torch_tensor.dtype == torch.bfloat16
5961
if is_bfloat16:
@@ -67,6 +69,7 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array:
6769
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
6870
return jax_array
6971

72+
7073
def rename_key(key):
7174
regex = r"\w+[.]\d+"
7275
pats = re.findall(regex, key)
@@ -132,7 +135,7 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
132135
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
133136
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
134137
return renamed_pt_tuple_key, pt_tensor
135-
138+
136139
# 3d conv layer
137140
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
138141
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5:

0 commit comments

Comments
 (0)