Skip to content

Commit 2bad084

Browse files
committed
Formatting
Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 8c45e6e commit 2bad084

5 files changed

Lines changed: 75 additions & 72 deletions

File tree

end_to_end/tpu/eval_assert.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
limitations under the License.
1515
"""
1616

17+
"""
18+
Example to run
19+
python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
20+
python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
21+
python end_to_end/tpu/eval_assert.py final_loss metrics.txt 2.5 100
22+
"""
23+
24+
25+
1726
# pylint: skip-file
1827
"""Reads and asserts over target values"""
1928
from absl import app
@@ -34,11 +43,12 @@ def get_last_n_data(metrics_file, target, n=10):
3443
return last_n_data
3544

3645

37-
def test_final_loss(metrics_file, target_loss):
46+
def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
3847
target_loss = float(target_loss)
48+
num_samples = int(num_samples_str)
3949
with open(metrics_file, "r", encoding="utf8") as _:
4050
use_last_n_data = 10
41-
last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data)
51+
last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples)
4252
avg_last_n_data = sum(last_n_data) / len(last_n_data)
4353
print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
4454
print(f"Target loss is {target_loss}")
@@ -61,8 +71,9 @@ def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"
6171
print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
6272
print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")
6373

64-
assert avg_last_n_step_time < max_avg_step_time, \
65-
f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
74+
assert (
75+
avg_last_n_step_time < max_avg_step_time
76+
), f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
6677
print("Average step time test passed.")
6778

6879

@@ -82,8 +93,9 @@ def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
8293
print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
8394
print(f"Mean of last {len(last_n_tflops)} steps TFLOPs/sec is {avg_last_n_tflops:.2f}")
8495

85-
assert avg_last_n_tflops > min_avg_tflops, \
86-
f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
96+
assert (
97+
avg_last_n_tflops > min_avg_tflops
98+
), f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
8799
print("Average TFLOPs/sec test passed.")
88100

89101

@@ -118,4 +130,4 @@ def main(argv: Sequence[str]) -> None:
118130

119131

120132
if __name__ == "__main__":
121-
app.run(main)
133+
app.run(main)

src/maxdiffusion/max_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,11 +564,12 @@ def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSe
564564
total_flops = (total_flops * 3 if train else total_flops) / 10**12
565565
return total_flops
566566

567-
def get_train_step_partial_with_signature(train_step:Callable, pipeline:object, params:Dict, config:object)->Callable:
567+
568+
def get_train_step_partial_with_signature(train_step: Callable, pipeline: object, params: Dict, config: object) -> Callable:
568569
partial_train = partial(train_step, pipeline=pipeline, params=params, config=config)
569570
partial_train.__name__ = "train_step"
570571
return partial_train
571-
572+
572573

573574
def calculate_num_params_from_pytree(params):
574575
"""Calculates number of parameters from a pytree"""

src/maxdiffusion/trainers/base_stable_diffusion_trainer.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@
2424

2525
# Define a filename for logging
2626

27-
def _log_to_file(message: str, log_file:str=""):
28-
"""Appends a message to the global log file with a timestamp."""
29-
timestamp = time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime())
30-
full_message = f"[{timestamp}] {message}\n"
31-
if log_file:
32-
with open(log_file, 'a') as f:
33-
f.write(full_message)
34-
max_logging.log(full_message.strip())
35-
27+
28+
def _log_to_file(message: str, log_file: str = ""):
29+
"""Appends a message to the global log file with a timestamp."""
30+
timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z", time.localtime())
31+
full_message = f"[{timestamp}] {message}\n"
32+
if log_file:
33+
with open(log_file, "a") as f:
34+
f.write(full_message)
35+
max_logging.log(full_message.strip())
3636

3737

3838
class BaseStableDiffusionTrainer(BaseStableDiffusionCheckpointer):
@@ -80,32 +80,28 @@ def get_data_shardings(self):
8080
@abstractmethod
8181
def create_scheduler(self, pipeline, params):
8282
pass
83-
83+
8484
def _time_and_log_call(
85-
self,
86-
func_obj: Callable[..., Any],
87-
*func_args: Any,
88-
description: str = "",
89-
**func_kwargs: Any
90-
) -> Any:
85+
self, func_obj: Callable[..., Any], *func_args: Any, description: str = "", **func_kwargs: Any
86+
) -> Any:
9187
"""
9288
Times a function call, logs its duration, and returns its result.
9389
"""
9490
if not description:
95-
if hasattr(func_obj, '__name__'):
91+
if hasattr(func_obj, "__name__"):
9692
description = func_obj.__name__
97-
elif hasattr(func_obj, '__call__') and hasattr(type(func_obj), '__name__'):
93+
elif hasattr(func_obj, "__call__") and hasattr(type(func_obj), "__name__"):
9894
description = type(func_obj).__name__
9995
log_file = ""
100-
96+
10197
if self.config.write_timing_metrics and self.config.timing_metrics_file:
10298
log_file = self.config.get.timing_metrics_file
10399
_log_to_file(f"Starting: {description}...", log_file=log_file)
104-
start_time = time.perf_counter() # Use perf_counter for more precise duration measurement
100+
start_time = time.perf_counter() # Use perf_counter for more precise duration measurement
105101
result = func_obj(*func_args, **func_kwargs)
106102
end_time = time.perf_counter()
107103
duration = end_time - start_time
108-
_log_to_file(f"Finished: {description} - Duration: {duration:.4f} seconds",log_file=log_file)
104+
_log_to_file(f"Finished: {description} - Duration: {duration:.4f} seconds", log_file=log_file)
109105
return result
110106

111107
def calculate_tflops(self, pipeline, params):
@@ -129,7 +125,7 @@ def start_training(self):
129125
pipeline=pipeline,
130126
params=params,
131127
checkpoint_item_name="vae_state",
132-
is_training=False
128+
is_training=False,
133129
)
134130

135131
train_states["vae_state"] = vae_state
@@ -147,13 +143,13 @@ def start_training(self):
147143
state_shardings["text_encoder_state_shardings"] = text_encoder_state_mesh_shardings
148144
if hasattr(pipeline, "text_encoder_2"):
149145
text_encoder_2_state, text_encoder_2_state_mesh_shardings = self._time_and_log_call(
150-
self.create_text_encoder_2_state,
151-
# Arguments for create_text_encoder_2_state
152-
pipeline=pipeline,
153-
params=params,
154-
checkpoint_item_name="text_encoder_2_state",
155-
is_training=self.config.train_text_encoder,
156-
)
146+
self.create_text_encoder_2_state,
147+
# Arguments for create_text_encoder_2_state
148+
pipeline=pipeline,
149+
params=params,
150+
checkpoint_item_name="text_encoder_2_state",
151+
is_training=self.config.train_text_encoder,
152+
)
157153
train_states["text_encoder_2_state"] = text_encoder_2_state
158154
state_shardings["text_encoder_2_state_shardings"] = text_encoder_2_state_mesh_shardings
159155

@@ -167,17 +163,9 @@ def start_training(self):
167163
self.per_device_tflops = per_device_tflops
168164

169165
# Load dataset
170-
data_iterator = self._time_and_log_call(
171-
self.load_dataset,
172-
pipeline,
173-
params,
174-
train_states
175-
)
166+
data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states)
176167
if self.config.dataset_type == "grain":
177-
data_iterator = self._time_and_log_call(
178-
self.restore_data_iterator_state,
179-
data_iterator=data_iterator
180-
)
168+
data_iterator = self._time_and_log_call(self.restore_data_iterator_state, data_iterator=data_iterator)
181169

182170
unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self._time_and_log_call(
183171
self.create_unet_state,
@@ -196,13 +184,12 @@ def start_training(self):
196184
data_shardings = self.get_data_shardings()
197185
# Compile train_step
198186
p_train_step = self._time_and_log_call(
199-
self.compile_train_step,
200-
pipeline, params, train_states, state_shardings, data_shardings
201-
)
187+
self.compile_train_step, pipeline, params, train_states, state_shardings, data_shardings
188+
)
202189
# Start training
203-
train_states = self._time_and_log_call(self.training_loop,
204-
p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
190+
train_states = self._time_and_log_call(
191+
self.training_loop, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
205192
)
206193
# 6. save final checkpoint
207194
# Hook
208-
self._time_and_log_call(self.post_training_steps,pipeline, params, train_states)
195+
self._time_and_log_call(self.post_training_steps, pipeline, params, train_states)

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,12 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
286286
guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance_scale, dtype=self.config.activations_dtype)
287287
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
288288
train_step_partial = partial(
289-
_train_step,
290-
guidance_vec=guidance_vec,
291-
pipeline=pipeline,
292-
scheduler=train_states["scheduler"],
293-
config=self.config,
294-
)
289+
_train_step,
290+
guidance_vec=guidance_vec,
291+
pipeline=pipeline,
292+
scheduler=train_states["scheduler"],
293+
config=self.config,
294+
)
295295
train_step_partial.__name__ = "train_step"
296296
p_train_step = jax.jit(
297297
train_step_partial,

src/maxdiffusion/utils/dynamic_modules_utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from huggingface_hub import HfFolder, hf_hub_download, model_info
2929
import huggingface_hub
3030
from packaging import version
31+
3132
cached_download = None
3233

3334
from .. import __version__
@@ -42,20 +43,22 @@
4243
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4344

4445
# https://github.com/huggingface/huggingface_hub/releases/tag/v0.26.0
45-
# `cached_download(), url_to_filename(), filename_to_url() methods are now completely removed.
46+
# `cached_download(), url_to_filename(), filename_to_url() methods are now completely removed.
4647
# From now on, you will have to use hf_hub_download() to benefit from the new cache layout.`
47-
if hasattr(huggingface_hub, '__version__'):
48-
current_version = version.parse(huggingface_hub.__version__)
49-
target_version = version.parse("0.26.0")
50-
51-
if current_version < target_version:
52-
try:
53-
from huggingface_hub import cached_download
54-
55-
except ImportError:
56-
logger.error(f"huggingface_hub version {current_version} is below 0.26.0, but 'cached_download' could not be imported. It might have been removed or deprecated in this version as well.")
48+
if hasattr(huggingface_hub, "__version__"):
49+
current_version = version.parse(huggingface_hub.__version__)
50+
target_version = version.parse("0.26.0")
51+
52+
if current_version < target_version:
53+
try:
54+
from huggingface_hub import cached_download
55+
56+
except ImportError:
57+
logger.error(
58+
f"huggingface_hub version {current_version} is below 0.26.0, but 'cached_download' could not be imported. It might have been removed or deprecated in this version as well."
59+
)
5760
else:
58-
logger.error("Could not determine huggingface_hub version. Unable to conditionally import 'cached_download'.")
61+
logger.error("Could not determine huggingface_hub version. Unable to conditionally import 'cached_download'.")
5962

6063

6164
def get_diffusers_versions():

0 commit comments

Comments (0)