diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py
index 5f64d4880..8f1e2654e 100644
--- a/src/maxdiffusion/checkpointing/wan_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -15,7 +15,6 @@
 """

 from abc import ABC
-from flax import nnx
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
@@ -42,7 +41,7 @@ def _create_optimizer(self, model, config, learning_rate):
         learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
     )
     tx = max_utils.create_optimizer(config, learning_rate_scheduler)
-    return nnx.Optimizer(model, tx), learning_rate_scheduler
+    return tx, learning_rate_scheduler

   def load_wan_configs_from_orbax(self, step):
     max_logging.log("Restoring stable diffusion configs")
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 911127896..b552b0621 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -54,6 +54,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+flash_min_seq_length: 4096
 flash_block_sizes: {}

 # Use on v6e
@@ -126,15 +127,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
                       ['batch', 'data'],
+                      ['activation_batch', 'data'],
                       ['activation_length', 'fsdp'],
+                      ['activation_heads', 'tensor'],
-                      ['activation_batch', 'data'],
                       ['mlp','tensor'],
                       ['embed','fsdp'],
+                      ['heads', 'tensor'],
                       ['norm', 'tensor'],
                       ['conv_batch', ['data','fsdp']],
                       ['out_channels', 'tensor'],
-                      ['conv_in', 'fsdp'],
+                      ['conv_out', 'fsdp'],
                     ]
 data_sharding: [['data', 'fsdp', 'tensor']]
@@ -182,6 +185,14 @@ transform_images_num_proc: 4
 reuse_example_batch: False
 enable_data_shuffling: True

+# Defines the type of gradient checkpointing to enable.
+# NONE - no gradient checkpointing
+# FULL - full gradient checkpointing, wherever possible (minimum memory usage)
+# MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation,
+# except for ones that involve the batch dimension - that means all attention and projection
+# layers are checkpointed, but not the backward pass with respect to the parameters
+remat_policy: "NONE"
+
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
@@ -196,7 +207,7 @@ max_train_steps: 1500
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
-per_device_batch_size: 1
+per_device_batch_size: 1.0

 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index d3c8d47cf..a9bcf366c 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -29,15 +29,9 @@ def run(config, pipeline=None, filename_prefix=""):
     pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()

-  # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
-  global_batch_size = config.global_batch_size
-  if global_batch_size != 0:
-    batch_multiplier = global_batch_size
-  else:
-    batch_multiplier = jax.device_count() * config.per_device_batch_size
-
-  prompt = [config.prompt] * batch_multiplier
-  negative_prompt = [config.negative_prompt] * batch_multiplier
+  # Using global_batch_size_to_train_on so as not to create more config variables
+  prompt = [config.prompt] * config.global_batch_size_to_train_on
+  negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on

   max_logging.log(
       f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index ce0ae5169..562d5c718 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -117,13 +117,16 @@ def make_tfrecord_iterator(
   check out preparation script maxdiffusion/pedagogical_examples/to_tfrecords.py
   """
-
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+
+  # Checks that the dataset path is valid. In the case of GCS, the existence of the dir is not checked.
+  is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
+
   if (
       config.cache_latents_text_encoder_outputs
-      and os.path.isdir(config.dataset_save_location)
+      and is_dataset_dir_valid
       and "load_tfrecord_cached" in config.get_keys()
       and config.load_tfrecord_cached
   ):
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 3099a5bc0..fe86e08c4 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -18,6 +18,7 @@
 import flax.linen as nn
 from flax import nnx
 import jax
+from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
@@ -187,30 +188,6 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size = mesh.shape["tensor"]
-
-  @functools.partial(
-      jax.jit,
-      static_argnames=["multi_head_mask", "shard_head_size"],
-  )
-  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=shard_head_size,  # the sizes of the axis is sharding over heads
-        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-        block_sizes=block_sizes,
-    )
-    return splash_kernel
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

   @functools.partial(
       shard_map.shard_map,
@@ -219,12 +196,21 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          segment_axis_names_splash_kernel,
       ),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, splash_kernel):
+  def wrap_flash_attention(query, key, value):
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # make_splash_mha is wrapped in shard_map, and seq and head are already
+    # sharded based on in_specs, therefore head_shards=1 and q_seq_shards=1.
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # size of the axis sharded over heads
+        q_seq_shards=1,  # size of the axis sharded over seq_len
+        block_sizes=block_sizes,
+    )
     attention_output = jax.vmap(splash_kernel)(query, key, value)
     return attention_output

@@ -236,7 +222,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value, splash_kernel)
+  x = wrap_flash_attention(query, key, value)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

@@ -632,7 +618,7 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       attention_kernel: str = "flash",
-      flash_min_seq_length: int = 4096,
+      flash_min_seq_length: int = 0,
       flash_block_sizes: BlockSizes = None,
       mesh: jax.sharding.Mesh = None,
       dtype: jnp.dtype = jnp.float32,
@@ -809,12 +795,16 @@ def __call__(
     query_proj = _unflatten_heads(query_proj, self.heads)
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
+    # output of _unflatten_heads is (batch, heads, seq_len, head_dim)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+    query_proj = checkpoint_name(query_proj, "query_proj")
+    key_proj = checkpoint_name(key_proj, "key_proj")
+    value_proj = checkpoint_name(value_proj, "value_proj")
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
-
+    attn_output = checkpoint_name(attn_output, "attn_output")
     hidden_states = self.proj_attn(attn_output)
     return hidden_states
diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py
new file mode 100644
index 000000000..28f637c23
--- /dev/null
+++ b/src/maxdiffusion/models/gradient_checkpoint.py
@@ -0,0 +1,93 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+""" + +from enum import Enum, auto +from typing import Optional + +import jax +from jax import checkpoint_policies as cp +from flax import nnx + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +# This class only works with NNX modules. +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + ATTN = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.ATTN: + return cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host" + ) + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nnx.Module) -> nnx.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nnx.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index b1ae70b7a..6588929b1 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -18,6 +18,7 @@ import math import jax import jax.numpy as jnp +from jax.sharding import PartitionSpec from flax import nnx import numpy as np from .... 
@@ -31,6 +32,7 @@
 )
 from ...normalization_flax import FP32LayerNorm
 from ...attention_flax import FlaxWanAttention
+from ...gradient_checkpoint import GradientCheckpointType

 BlockSizes = common_types.BlockSizes

@@ -170,6 +172,15 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.xavier_uniform(),
+            (
+                None,
+                "mlp",
+                "embed",
+            ),
+        ),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -217,8 +228,9 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
+                None,
                 "embed",
+                "mlp",
             ),
         ),
     )
@@ -300,12 +312,14 @@ def __init__(
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

     key = rngs.params()
-    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)
+    self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)

   def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array):
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.scale_shift_table + temb), 6, axis=1
+        (self.adaln_scale_shift_table + temb), 6, axis=1
     )
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
@@ -356,6 +370,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
+      remat_policy: str = "None",
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -374,13 +389,7 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                None,
-                None,
-                None,
-                None,
-                "conv_out",
-            ),
+            (None, None, None, None, "conv_out"),
         ),
     )

@@ -417,6 +426,8 @@ def init_block(rngs):
           attention=attention,
       )

+    self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
+
     self.blocks = init_block(rngs)

     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
@@ -452,6 +463,7 @@ def __call__(
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))

     rotary_emb = self.rope(hidden_states)
+
     hidden_states = self.patch_embedding(hidden_states)
     hidden_states = jax.lax.collapse(hidden_states, 1, -1)
@@ -469,8 +481,9 @@ def scan_fn(carry, block):
       return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

     initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+    rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
     final_carry = nnx.scan(
-        scan_fn,
+        rematted_block_forward,
         length=self.num_layers,
         in_axes=(nnx.Carry, 0),
         out_axes=nnx.Carry,
diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
index 5a27591d6..628207a9a 100644
--- a/src/maxdiffusion/models/wan/wan_utils.py
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -1,3 +1,19 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
 import os
 import json
 import torch
@@ -225,6 +241,7 @@ def load_base_wan_transformer(
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
     renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+    renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
     renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
     renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
     renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 9ca2e03b9..8d2f2cd3b 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -78,6 +78,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["attention"] = config.attention
   wan_config["precision"] = get_precision(config)
   wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
+  wan_config["remat_policy"] = config.remat_policy
+  wan_config["flash_min_seq_length"] = config.flash_min_seq_length

   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
@@ -414,7 +416,8 @@ def __call__(
     )

     data_sharding = NamedSharding(self.mesh, P())
-    if len(prompt) % jax.device_count() == 0:
+    # Using global_batch_size_to_train_on so as not to create more config variables
+    if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0:
       data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

     latents = jax.device_put(latents, data_sharding)
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 92dd2a992..8e758d661 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -137,6 +137,18 @@ def wan_init(raw_keys):
     else:
       raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")

+  @staticmethod
+  def calculate_global_batch_sizes(per_device_batch_size):
+    num_devices = len(jax.devices())
+    if per_device_batch_size < 1:
+      # For per_device_batch_size < 1, we load the data as if per_device_batch_size = 1
+      global_batch_size_to_load = num_devices
+    else:
+      global_batch_size_to_load = int(num_devices * per_device_batch_size)
+
+    global_batch_size_to_train_on = int(num_devices * per_device_batch_size)
+    return global_batch_size_to_load, global_batch_size_to_train_on
+
   @staticmethod
   def user_init(raw_keys):
     """Transformations between the config data and configs used at runtime"""
@@ -181,6 +193,9 @@ def user_init(raw_keys):
     raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
     raw_keys["num_slices"] = get_num_slices(raw_keys)
     raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
+    raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = (
+        _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
+    )


 def get_num_slices(raw_keys):
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index d568709f0..3b0b520bf 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -23,6 +23,7 @@
 import tensorflow as tf
 import jax.numpy as jnp
 import jax
+from jax.sharding import PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
@@ -34,6 +35,18 @@
 from maxdiffusion.video_processor import VideoProcessor
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
+from flax.training import train_state
+
+
+class TrainState(train_state.TrainState):
+  graphdef: nnx.GraphDef
+  rest_of_state: nnx.State
+
+
+def _to_array(x):
+  if not isinstance(x, jax.Array):
+    x = jnp.asarray(x)
+  return x


 def generate_sample(config, pipeline, filename_prefix):
@@ -69,8 +82,6 @@ def __init__(self, config):
     if config.train_text_encoder:
       raise ValueError("this script currently doesn't support training text_encoders")

-    self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
-
   def post_training_steps(self, pipeline, params, train_states, msg=""):
     pass

@@ -85,6 +96,11 @@ def calculate_tflops(self, pipeline):
     max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...")
     return 0

+  def get_data_shardings(self, mesh):
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
+    return data_sharding
+
   def load_dataset(self, mesh):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
@@ -115,7 +131,7 @@ def prepare_sample(features):
         jax.process_index(),
         jax.process_count(),
         mesh,
-        self.global_batch_size,
+        config.global_batch_size_to_load,
         feature_description=feature_description,
         prepare_sample_fn=prepare_sample,
     )
@@ -135,9 +151,7 @@ def start_training(self):
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
-
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
-
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)

@@ -145,14 +159,24 @@ def start_training(self):
     print_ssim(pretrained_video_path, posttrained_video_path)

   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
+    mesh = pipeline.mesh
+    graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
-    graphdef, state = nnx.split((pipeline.transformer, optimizer))
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      state = TrainState.create(
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
+      )
+      state = jax.tree.map(_to_array, state)
+      state_spec = nnx.get_partition_spec(state)
+      state = jax.lax.with_sharding_constraint(state, state_spec)
+    state_shardings = nnx.get_named_sharding(state, mesh)
+    data_shardings = self.get_data_shardings(mesh)
+
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
     writer_thread.start()

-    num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0])
+    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
     max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
     max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
     max_utils.add_config_to_summary_writer(self.config, writer)

@@ -160,12 +184,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     if jax.process_index() == 0:
       max_logging.log("***** Running training *****")
       max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}")
-      max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}")
+      max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.config.global_batch_size_to_train_on}")
       max_logging.log(f" Total optimization steps = {self.config.max_train_steps}")

-    state = state.to_pure_dict()
     p_train_step = jax.jit(
         functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
@@ -194,7 +219,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
       with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
          self.config.logical_axis_rules
       ):
-        state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng)
+        state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)

       train_metric["scalar"]["learning/loss"].block_until_ready()
       last_step_completion = datetime.datetime.now()
@@ -214,19 +239,22 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
       writer.flush()

   # load new state for trained tranformer
-  graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
-  pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)
+  pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
   return pipeline


-def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config):
-  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config)
+def train_step(state, data, rng, scheduler_state, scheduler, config):
+  return step_optimizer(state, data, rng, scheduler_state, scheduler, config)


-def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config):
+def step_optimizer(state, data, rng, scheduler_state, scheduler, config):
   _, new_rng, timestep_rng = jax.random.split(rng, num=3)

-  def loss_fn(model):
+  for k, v in data.items():
+    data[k] = v[: config.global_batch_size_to_train_on, :]
+
+  def loss_fn(params):
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
     latents = data["latents"].astype(config.weights_dtype)
     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
     bsz = latents.shape[0]
@@ -253,10 +281,8 @@ def loss_fn(model):

     return loss

-  model, optimizer = nnx.merge(graphdef, state)
-  loss, grads = nnx.value_and_grad(loss_fn)(model)
-  optimizer.update(grads)
-  state = nnx.state((model, optimizer))
-  state = state.to_pure_dict()
+  grad_fn = nnx.value_and_grad(loss_fn)
+  loss, grads = grad_fn(state.params)
+  new_state = state.apply_gradients(grads=grads)
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
-  return state, scheduler_state, metrics, new_rng
+  return new_state, scheduler_state, metrics, new_rng
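
A note on how the new remat_policy is wired up: the YAML value is parsed by GradientCheckpointType.from_str, stored on the transformer, and applied to the scan body via self.gradient_checkpoint.apply(scan_fn), which ultimately calls nnx.remat with a JAX checkpoint policy. Below is a minimal sketch of that mechanism; the Block module, its dimensions, and the block_forward helper are made up for illustration and are not part of this change.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
  """Toy stand-in for a transformer block (illustrative only)."""

  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.fc1 = nnx.Linear(dim, 4 * dim, rngs=rngs)
    self.fc2 = nnx.Linear(4 * dim, dim, rngs=rngs)

  def __call__(self, x):
    return self.fc2(jax.nn.gelu(self.fc1(x)))


def block_forward(block, x):
  # Plays the role of scan_fn in transformer_wan.py.
  return block(x)


# remat_policy: "MATMUL_WITHOUT_BATCH" maps to this jax policy in gradient_checkpoint.py.
rematted_forward = nnx.remat(
    block_forward,
    prevent_cse=False,
    policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
)

block = Block(8, nnx.Rngs(0))
x = jnp.ones((2, 16, 8))
# Forward results are unchanged; intermediates are recomputed during the backward pass.
y = rematted_forward(block, x)
```

With remat_policy: "NONE" (the default), GradientCheckpointType.apply returns the callable unwrapped, so existing configs see no behavior change.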
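The per_device_batch_size: 1.0 change and the new global_batch_size_to_load / global_batch_size_to_train_on pair allow fractional per-device batch sizes: the input pipeline always loads at least one sample per device, and step_optimizer slices each batch down to the train size. A small sketch of the arithmetic, with num_devices passed in explicitly instead of the len(jax.devices()) used by pyconfig.py:

```python
def calculate_global_batch_sizes(per_device_batch_size: float, num_devices: int):
  # Mirrors _HyperParameters.calculate_global_batch_sizes, with num_devices made explicit.
  if per_device_batch_size < 1:
    # For per_device_batch_size < 1, the data pipeline still loads one sample per device.
    global_batch_size_to_load = num_devices
  else:
    global_batch_size_to_load = int(num_devices * per_device_batch_size)
  global_batch_size_to_train_on = int(num_devices * per_device_batch_size)
  return global_batch_size_to_load, global_batch_size_to_train_on


print(calculate_global_batch_sizes(1.0, 8))  # (8, 8)
print(calculate_global_batch_sizes(0.5, 8))  # (8, 4): load 8 samples, train on 4
```

The slice in step_optimizer, data[k] = v[: config.global_batch_size_to_train_on, :], is what drops the extra loaded samples in the fractional case.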
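The training loop now threads the optimizer through a flax TrainState that carries the NNX graphdef and non-parameter state, instead of nnx.Optimizer, which lets p_train_step declare in_shardings/out_shardings and donate the whole state. A stripped-down sketch of the same pattern; the toy nnx.Linear model, optax.adamw optimizer, and dummy batch are placeholders rather than the actual WAN transformer and config.

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax.training import train_state


class TrainState(train_state.TrainState):
  graphdef: nnx.GraphDef
  rest_of_state: nnx.State


model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # placeholder for pipeline.transformer
graphdef, params, rest_of_state = nnx.split(model, nnx.Param, ...)
state = TrainState.create(
    apply_fn=graphdef.apply, params=params, tx=optax.adamw(1e-5), graphdef=graphdef, rest_of_state=rest_of_state
)


def loss_fn(params, batch):
  # Merge the trainable params back into a live module, as step_optimizer does.
  model = nnx.merge(state.graphdef, params, state.rest_of_state)
  return jnp.mean((model(batch["x"]) - batch["y"]) ** 2)


batch = {"x": jnp.ones((2, 4)), "y": jnp.zeros((2, 4))}
loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
state = state.apply_gradients(grads=grads)
```

jax.value_and_grad is used here for brevity; the trainer itself calls nnx.value_and_grad over state.params, then rebuilds the transformer at the end of training with nnx.merge(state.graphdef, state.params, state.rest_of_state).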