diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py
index a5e1bfc2f..89ac3764c 100644
--- a/src/maxdiffusion/checkpointing/flux_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -194,7 +194,7 @@ def load_diffusers_checkpoint(self):
     clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
     t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
     t5_tokenizer = AutoTokenizer.from_pretrained(
-        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+        self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True
     )

     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
@@ -263,7 +263,7 @@ def load_checkpoint(self, step=None, scheduler_class=None):
         self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype
     )
     t5_tokenizer = AutoTokenizer.from_pretrained(
-        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+        self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True
     )

     vae = FlaxAutoencoderKL.from_config(
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index 454f65785..ce0ae5169 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -18,7 +18,7 @@
 import tensorflow as tf
 import tensorflow.experimental.numpy as tnp
 from datasets import load_dataset, load_from_disk
-
+import jax
 from maxdiffusion import multihost_dataloading

 AUTOTUNE = tf.data.AUTOTUNE
@@ -65,8 +65,13 @@ def make_tf_iterator(
   )
   if config.cache_latents_text_encoder_outputs:
     train_ds.save_to_disk(config.dataset_save_location)
-    train_ds.cleanup_cache_files()
-
+    # Only process 0 should attempt to clean up cache files
+    if jax.process_index() == 0:
+      try:
+        train_ds.cleanup_cache_files()
+      except FileNotFoundError:
+        # Ignore FileNotFoundError as files may have been cleaned up by another process
+        pass
   train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, dataloading_host_count)
   train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index aaa929c59..fe6cc09af 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -26,7 +26,6 @@
 import os
 from pathlib import Path
 import subprocess
-
 import numpy as np
 import flax

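Note on the AutoTokenizer change above: `max_length` is not a kwarg that Hugging Face tokenizer constructors act on, so the T5 tokenizer silently kept its default limit; `model_max_length` is the recognized kwarg, and it is the value `truncation=True` falls back to when no per-call `max_length` is given. A minimal sketch of the behavior (the `t5-small` checkpoint and the 512 limit are illustrative, not the repo's config):

    from transformers import AutoTokenizer

    # `model_max_length` is the constructor kwarg the tokenizer recognizes;
    # an unrecognized `max_length` kwarg would not cap tokenization.
    tok = AutoTokenizer.from_pretrained("t5-small", model_max_length=512, use_fast=True)

    # With truncation=True and no per-call max_length, sequences are capped
    # at tok.model_max_length (512 here).
    ids = tok("a very long prompt " * 300, truncation=True)["input_ids"]
    assert len(ids) <= tok.model_max_length
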
diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
index 97a31ebe9..505a4f4ff 100644
--- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
+++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
@@ -86,7 +86,7 @@ def setup(self):
     self.linear1 = nn.Dense(
         self.dim * 3 + self.mlp_hidden_dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,
@@ -96,7 +96,7 @@ def setup(self):
     self.linear2 = nn.Dense(
         self.dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,
@@ -209,7 +209,7 @@ def setup(self):
           int(self.dim * self.mlp_ratio),
           use_bias=True,
           kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,
@@ -218,8 +218,8 @@ def setup(self):
         nn.Dense(
           self.dim,
           use_bias=True,
-          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,
@@ -240,7 +240,7 @@ def setup(self):
           int(self.dim * self.mlp_ratio),
           use_bias=True,
           kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,
@@ -249,8 +249,8 @@ def setup(self):
         nn.Dense(
           self.dim,
           use_bias=True,
-          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,
@@ -483,6 +483,9 @@ def __call__(
   ):
     hidden_states = self.img_in(hidden_states)
     timestep = self.timestep_embedding(timestep, 256)
+
+    timestep = nn.with_logical_constraint(timestep, ("activation_batch", None))
+
     if self.guidance_embeds:
       guidance = self.timestep_embedding(guidance, 256)
     else:
@@ -492,6 +495,9 @@ def __call__(
       if guidance is None
       else self.time_text_embed(timestep, guidance, pooled_projections)
     )
+
+    temb = nn.with_logical_constraint(temb, ("activation_batch", None))
+
     encoder_hidden_states = self.txt_in(encoder_hidden_states)
     if txt_ids.ndim == 3:
       txt_ids = txt_ids[0]
@@ -501,7 +507,7 @@ def __call__(
     ids = jnp.concatenate((txt_ids, img_ids), axis=0)
     ids = nn.with_logical_constraint(ids, ("activation_batch", None))
     image_rotary_emb = self.pe_embedder(ids)
-    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
+    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, (None, None))

     for double_block in self.double_blocks:
       hidden_states, encoder_hidden_states = double_block(
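Note on the partitioning changes above: `nn.with_logical_partitioning` takes one logical axis name per parameter dimension. A `Dense` kernel has shape `(in_features, out_features)`, so a projection from the MLP hidden width back to the embedding width must be annotated `("mlp", "embed")`; the old `("embed", "mlp")` named the axes in the wrong order. A bias is 1-D over the output features, and `(None,)` simply replicates it instead of tying it to a sharded axis. A self-contained sketch of the convention (module name and sizes are illustrative):

    import flax.linen as nn

    class Mlp(nn.Module):
      dim: int = 512
      hidden: int = 2048

      @nn.compact
      def __call__(self, x):
        # Up-projection: kernel shape (dim, hidden) -> logical axes ("embed", "mlp").
        x = nn.Dense(
            self.hidden,
            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            # Bias is 1-D, so it gets exactly one name; None means "replicate".
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
        )(x)
        x = nn.gelu(x)
        # Down-projection: kernel shape (hidden, dim) -> the names flip to ("mlp", "embed").
        return nn.Dense(
            self.dim,
            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
        )(x)

The added `nn.with_logical_constraint` calls apply the same idea to activations: `("activation_batch", None)` shards `timestep` and `temb` along the batch axis only, while the rotary embedding is computed from batch-independent `ids` (note the `[0]` squeeze above it), so `(None, None)` replicates it instead of treating its leading axis as a batch axis.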
diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py
index 32139faef..f3ca99cb2 100644
--- a/src/maxdiffusion/trainers/flux_trainer.py
+++ b/src/maxdiffusion/trainers/flux_trainer.py
@@ -252,6 +252,7 @@ def load_dataset(self, pipeline, params, train_states):
         t5_tokenizer=pipeline.t5_tokenizer,
         clip_text_encoder=pipeline.clip_encoder,
         t5_text_encoder=pipeline.t5_encoder,
+        max_sequence_length=config.max_sequence_length,
         encode_in_batches=True,
         encode_batch_size=16,
     )
@@ -348,9 +349,13 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
       example_batch = load_next_batch(data_iterator, example_batch, self.config)
       example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}

-      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+      if self.config.profiler == 'nsys':
         with self.mesh:
           flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
+      else:
+        with jax.profiler.StepTraceAnnotation("train", step_num=step):
+          with self.mesh:
+            flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)

       samples_count = self.total_train_batch_size * (step + 1)
       new_time = datetime.datetime.now()
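Note on the training-loop change: `jax.profiler.StepTraceAnnotation` marks step boundaries for JAX's own (TensorBoard) profiler, which is presumably meaningless under Nsight Systems, so the `nsys` path runs the step without it. The control flow reduces to the sketch below (the standalone helper is hypothetical; the `config.profiler` field follows the diff, and such a run would typically be launched under `nsys profile`):

    import jax

    def run_train_step(config, mesh, p_train_step, state, batch, rngs, step):
      # Under nsys, skip JAX's step annotation; nsys records its own timeline.
      if config.profiler == "nsys":
        with mesh:
          return p_train_step(state, batch, rngs)
      # Otherwise, annotate the step so it shows up in the JAX profiler trace.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        with mesh:
          return p_train_step(state, batch, rngs)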