diff --git a/.github/actions/setup-miniconda/action.yml b/.github/actions/setup-miniconda/action.yml index cc755d3aa..5f85af7a2 100644 --- a/.github/actions/setup-miniconda/action.yml +++ b/.github/actions/setup-miniconda/action.yml @@ -7,7 +7,7 @@ inputs: description: If set to any value, dont use sudo to clean the workspace required: false type: string - default: "3.9" + default: "3.10" miniconda-version: description: Miniconda version to install required: false diff --git a/.github/workflows/XLML.yml b/.github/workflows/XLML.yml new file mode 100644 index 000000000..c9a3bf69b --- /dev/null +++ b/.github/workflows/XLML.yml @@ -0,0 +1,22 @@ +name: Add Testgrid Link to PR + +on: + pull_request: + types: [opened, synchronize] + +jobs: + add_testgrid_link: + runs-on: ubuntu-latest + permissions: + pull-requests: write + steps: + - name: Add link to PR description + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + curl -X POST \ + -H "Authorization: token $GITHUB_TOKEN" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/issues/$PR_NUMBER/comments" \ + -d '{ "body": "e2e testgrid: https://8bcf50593faf4ea38060e236169827e5-dot-us-central1.composer.googleusercontent.com/dags/maxdiffusion_tpu_e2e/grid" }' \ No newline at end of file diff --git a/src/maxdiffusion/mllog_utils.py b/src/maxdiffusion/mllog_utils.py deleted file mode 100644 index 2639697ac..000000000 --- a/src/maxdiffusion/mllog_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -""" - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -"""Utils that relevant to mllog for mlperf submission compliance.""" -import jax -from mlperf_logging import mllog -import numpy as np - -mllogger = mllog.get_mllogger() - - -def train_init_start(config): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.event(mllog.constants.CACHE_CLEAR) - mllogger.start(mllog.constants.INIT_START) - - -def train_init_stop(config): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.end(mllog.constants.INIT_STOP) - - -def train_run_start(config): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.start(mllog.constants.RUN_START) - - -def train_run_end(config): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.end(mllog.constants.RUN_STOP, metadata={"status": "success"}) - - -def train_init_print(config, device: str = "tpu-v5p"): - """an initial mllog for mlperf sumbission compliance check.""" - if jax.process_index() == 0 and config.enable_mllog: - mllogger.event(mllog.constants.SUBMISSION_ORG, "Google") - mllogger.event(mllog.constants.SUBMISSION_PLATFORM, device) - mllogger.event(mllog.constants.SUBMISSION_STATUS, mllog.constants.CLOUD) - mllogger.event(mllog.constants.SUBMISSION_DIVISION, mllog.constants.CLOSED) - mllogger.event(mllog.constants.SUBMISSION_BENCHMARK, mllog.constants.STABLE_DIFFUSION) - mllogger.event(mllog.constants.GRADIENT_ACCUMULATION_STEPS, 1) - mllogger.event(mllog.constants.GLOBAL_BATCH_SIZE, config.per_device_batch_size * jax.device_count()) - - mllogger.event(mllog.constants.OPT_NAME, mllog.constants.ADAMW) - mllogger.event(mllog.constants.OPT_ADAMW_BETA_1, config.adam_b1) - mllogger.event(mllog.constants.OPT_ADAMW_BETA_2, config.adam_b2) - mllogger.event(mllog.constants.OPT_ADAMW_EPSILON, config.adam_eps) - mllogger.event(mllog.constants.OPT_ADAMW_WEIGHT_DECAY, config.adam_weight_decay) - - mllogger.event(mllog.constants.OPT_BASE_LR, config.learning_rate) - mllogger.event( - mllog.constants.OPT_LR_WARMUP_STEPS, int(config.learning_rate_schedule_steps * config.warmup_steps_fraction) - ) - - # Training: a subset of laion-400m - # Validation: a subset of coco-2014 validation - mllogger.event(mllog.constants.TRAIN_SAMPLES, 6513144) - mllogger.event(mllog.constants.EVAL_SAMPLES, 30000) - - mllogger.event(mllog.constants.SEED, config.seed) - - -def train_step_start(config, step): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.start( - mllog.constants.BLOCK_START, - value="training_step", - metadata={ - "step_num": step, - }, - ) - - -def train_step_end(config, step, loss, lr): - if jax.process_index() == 0 and config.enable_mllog: - mllogger.end( - mllog.constants.BLOCK_STOP, - value="training_step", - metadata={ - "step_num": step, - "loss": loss, - "lr": lr, - }, - ) - - -def maybe_train_step_log(config, start_step, step, metric, train_log_interval: int = 100): - if step > start_step and step % train_log_interval == 0 or step == config.max_train_steps - 1 and config.enable_mllog: - # convert the jax array to a numpy array for mllog JSON encoding - loss = np.asarray(metric["scalar"]["learning/loss"]) - lr = np.asarray(metric["scalar"]["learning/current_learning_rate"]) - - train_step_end(config, step, loss, lr) - # start new tracking except the last step - if step < config.max_train_steps - 1: - train_step_start(config, step) diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index e3b161039..c37fdfa02 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -21,7 +21,6 @@ from maxdiffusion import ( max_logging, pyconfig, - mllog_utils, ) from maxdiffusion.train_utils import ( @@ -39,7 +38,6 @@ def train(config): def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) config = pyconfig.config - mllog_utils.train_init_start(config) validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") train(config) diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index cd8021556..64b0cd3bc 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -21,7 +21,6 @@ from maxdiffusion import ( max_logging, pyconfig, - mllog_utils, ) from maxdiffusion.trainers.sdxl_trainer import StableDiffusionXLTrainer @@ -39,7 +38,6 @@ def train(config): def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) config = pyconfig.config - mllog_utils.train_init_start(config) validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") train(config)