11 changes: 9 additions & 2 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -56,8 +56,9 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
dropout: 0.1

flash_block_sizes: {}
# Use on v6e
@@ -193,8 +194,14 @@ enable_data_shuffling: True
# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
# except for ones that involve batch dimension - that means that all attention and projection
# layers will have gradient checkpoint, but not the backward with respect to the parameters
# layers will have gradient checkpoint, but not the backward with respect to the parameters.
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
# CUSTOM - set names to offload and save.
remat_policy: "NONE"
# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
# xq_out, xk_out, ffn_activation
names_which_can_be_saved: []
names_which_can_be_offloaded: []

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
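
The CUSTOM remat policy added above is driven by the two new list keys. A minimal sketch of how they might be filled in, written as Python-style launch overrides — the tensor names come from the annotation comment in the config, but this particular split between saving and offloading is illustrative, not taken from this PR:

```python
# Hypothetical overrides for the CUSTOM remat policy.
# Names in `names_which_can_be_saved` are kept on device instead of being
# recomputed; names in `names_which_can_be_offloaded` are moved to host memory
# and fetched back during the backward pass.
custom_remat_overrides = {
    "remat_policy": "CUSTOM",
    "names_which_can_be_saved": ["query_proj", "key_proj", "value_proj"],
    "names_which_can_be_offloaded": ["attn_output", "ffn_activation"],
}
```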
43 changes: 23 additions & 20 deletions src/maxdiffusion/configuration_utils.py
@@ -47,21 +47,24 @@

_re_configuration_file = re.compile(r"config\.(.*)\.json")


class CustomEncoder(json.JSONEncoder):
"""
Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
"""
def default(self, o):
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
if isinstance(o, type(jnp.dtype('bfloat16'))):
return str(o)
# Add fallbacks for other numpy types if needed
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
# Let the base class default method raise the TypeError for other types
return super().default(o)
"""
Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
"""

def default(self, o):
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
if isinstance(o, type(jnp.dtype("bfloat16"))):
return str(o)
# Add fallbacks for other numpy types if needed
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
# Let the base class default method raise the TypeError for other types
return super().default(o)


class FrozenDict(OrderedDict):

@@ -596,14 +599,14 @@ def to_json_saveable(value):
config_dict.pop("quant", None)
keys_to_remove = []
for key, value in config_dict.items():
# Check the type of the value by its class name to avoid import issues
if type(value).__name__ == 'Rngs':
keys_to_remove.append(key)
# Check the type of the value by its class name to avoid import issues
if type(value).__name__ == "Rngs":
keys_to_remove.append(key)

if keys_to_remove:
max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
for key in keys_to_remove:
config_dict.pop(key)
max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
for key in keys_to_remove:
config_dict.pop(key)

try:

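As a quick illustration of what the reformatted `CustomEncoder` handles, the sketch below serializes a dict containing a JAX dtype and a NumPy integer; the dict contents are made up for the example, and the import path simply mirrors the file edited above:

```python
import json

import jax.numpy as jnp
import numpy as np

from maxdiffusion.configuration_utils import CustomEncoder  # path assumed from the file above

# Without cls=CustomEncoder, json.dumps raises TypeError on the dtype object.
example_config = {"weights_dtype": jnp.dtype("bfloat16"), "seed": np.int64(0)}
print(json.dumps(example_config, cls=CustomEncoder))
# -> {"weights_dtype": "bfloat16", "seed": 0}
```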
53 changes: 29 additions & 24 deletions src/maxdiffusion/generate_wan.py
@@ -22,43 +22,47 @@
from maxdiffusion.utils import export_to_video
from google.cloud import storage


def upload_video_to_gcs(output_dir: str, video_path: str):
"""
Uploads a local video file to a specified Google Cloud Storage bucket.
"""
try:
path_without_scheme = output_dir.removeprefix("gs://")
parts = path_without_scheme.split('/', 1)
bucket_name = parts[0]
folder_name = parts[1] if len(parts) > 1 else ''
"""
Uploads a local video file to a specified Google Cloud Storage bucket.
"""
try:
path_without_scheme = output_dir.removeprefix("gs://")
parts = path_without_scheme.split("/", 1)
bucket_name = parts[0]
folder_name = parts[1] if len(parts) > 1 else ""

storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)

source_file_path = f"./{video_path}"
destination_blob_name = os.path.join(folder_name, "videos", video_path)
source_file_path = f"./{video_path}"
destination_blob_name = os.path.join(folder_name, "videos", video_path)

blob = bucket.blob(destination_blob_name)
blob = bucket.blob(destination_blob_name)

max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
blob.upload_from_filename(source_file_path)
max_logging.log(f"Upload complete {source_file_path}.")
max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
blob.upload_from_filename(source_file_path)
max_logging.log(f"Upload complete {source_file_path}.")

except Exception as e:
max_logging.log(f"An error occurred: {e}")

except Exception as e:
max_logging.log(f"An error occurred: {e}")

def delete_file(file_path: str):
if os.path.exists(file_path):
try:
os.remove(file_path)
max_logging.log(f"Successfully deleted file: {file_path}")
except OSError as e:
max_logging.log(f"Error deleting file '{file_path}': {e}")
try:
os.remove(file_path)
max_logging.log(f"Successfully deleted file: {file_path}")
except OSError as e:
max_logging.log(f"Error deleting file '{file_path}': {e}")
else:
max_logging.log(f"The file '{file_path}' does not exist.")
max_logging.log(f"The file '{file_path}' does not exist.")


jax.config.update("jax_use_shardy_partitioner", True)


def inference_generate_video(config, pipeline, filename_prefix=""):
s0 = time.perf_counter()
prompt = [config.prompt] * config.global_batch_size_to_train_on
@@ -88,6 +92,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
delete_file(f"./{video_path}")
return


def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
74 changes: 49 additions & 25 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -78,9 +78,18 @@ def make_tf_iterator(
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter


# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def _make_tfrecord_iterator(
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
feature_description_fn,
prepare_sample_fn,
dataset_path,
is_training: bool,
):
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
@@ -93,10 +102,10 @@ def _make_tfrecord_iterator(
# Determine whether to use the "cached" dataset, which requires externally
# provided parsing functions, or the default one with its internal parsing logic.
make_cached_tfrecord_iterator = (
config.cache_latents_text_encoder_outputs
and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
config.cache_latents_text_encoder_outputs
and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
)

feature_description = {
@@ -121,42 +130,47 @@ def prepare_sample(features):
if not is_training:
num_eval_samples = 0
for _ in ds:
num_eval_samples += 1
num_eval_samples += 1

remainder = num_eval_samples % global_batch_size
if remainder != 0:
num_to_pad = global_batch_size - remainder
# Create a dataset of padding samples from the beginning
padding_ds = ds.take(num_to_pad)
# Add the padding samples to the end
ds = ds.concatenate(padding_ds)
max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
num_to_pad = global_batch_size - remainder
# Create a dataset of padding samples from the beginning
padding_ds = ds.take(num_to_pad)
# Add the padding samples to the end
ds = ds.concatenate(padding_ds)
max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")

used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
ds = (
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
)
if is_training:
ds = (
ds.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.repeat(-1)
.prefetch(AUTOTUNE)
ds.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.repeat(-1)
.prefetch(AUTOTUNE)
)
# For Evaluation
else:
ds = (
ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
.prefetch(AUTOTUNE)
)
ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)

iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
return iter


def make_tfrecord_iterator(
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
feature_description,
prepare_sample_fn,
is_training,
):
"""Iterator for TFRecord format. For Laion dataset,
check out preparation script
@@ -165,4 +179,14 @@ def make_tfrecord_iterator(
# Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset.
# TODO: refactor to support evaluation on all dataset format.
dataset_path = config.train_data_dir if is_training else config.eval_data_dir
return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
return _make_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
feature_description,
prepare_sample_fn,
dataset_path,
is_training,
)
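
The eval-padding logic in the hunk above rounds the evaluation dataset up to a multiple of the global batch size by re-using samples from the front of the dataset. A small standalone sketch of the same arithmetic, with made-up numbers:

```python
# Illustrative numbers only: 103 eval samples with a global batch size of 8.
num_eval_samples = 103
global_batch_size = 8

remainder = num_eval_samples % global_batch_size  # 7
if remainder != 0:
    num_to_pad = global_batch_size - remainder  # 1 sample re-used from the start
    print(f"Padded evaluation dataset with {num_to_pad} samples.")
```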
@@ -107,7 +107,7 @@ def make_data_iterator(
global_batch_size,
feature_description,
prepare_sample_fn,
is_training
is_training,
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"