
Commit 043f826

Add dropout (#240)
* adds dropout
* refactors shard_map to include weights layer.
1 parent 34a9134 commit 043f826

13 files changed

Lines changed: 292 additions & 153 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 9 additions & 2 deletions
@@ -56,8 +56,9 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
+dropout: 0.1

 flash_block_sizes: {}
 # Use on v6e
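
The new dropout: 0.1 knob gives the WAN 14B config a dropout rate (the same hunk also lists ring as a supported attention kernel). Where the rate is consumed is not shown in this hunk; below is only a minimal sketch of how a config-driven rate is commonly wired in Flax NNX, with the module and its dimensions (FeedForward, dim, hidden) invented for illustration.

from flax import nnx


class FeedForward(nnx.Module):
  """Hypothetical block illustrating a config-driven dropout rate."""

  def __init__(self, dim: int, hidden: int, dropout: float, rngs: nnx.Rngs):
    self.fc1 = nnx.Linear(dim, hidden, rngs=rngs)
    self.fc2 = nnx.Linear(hidden, dim, rngs=rngs)
    self.drop = nnx.Dropout(rate=dropout, rngs=rngs)  # rate would come from config.dropout

  def __call__(self, x, deterministic: bool = False):
    x = nnx.gelu(self.fc1(x))
    x = self.drop(x, deterministic=deterministic)  # no-op when deterministic=True (inference)
    return self.fc2(x)


# Example: ff = FeedForward(dim=1024, hidden=4096, dropout=0.1, rngs=nnx.Rngs(0))
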
@@ -193,8 +194,14 @@ enable_data_shuffling: True
 # FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
 # MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
 # except for ones that involve batch dimension - that means that all attention and projection
-# layers will have gradient checkpoint, but not the backward with respect to the parameters
+# layers will have gradient checkpoint, but not the backward with respect to the parameters.
+# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
+# CUSTOM - set names to offload and save.
 remat_policy: "NONE"
+# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
+# xq_out, xk_out, ffn_activation
+names_which_can_be_saved: []
+names_which_can_be_offloaded: []

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
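
The two new lists only take effect with a name-aware remat policy. The exact wiring in maxdiffusion is not part of this hunk; a plausible mapping uses JAX's jax.checkpoint_policies.save_and_offload_only_these_names, with activations tagged via checkpoint_name so the policy can match them (the ffn function below is a made-up stand-in for a transformer sub-block).

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical mapping of the two config lists onto a JAX checkpoint policy.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["attn_output"],         # keep on device, never recompute
    names_which_can_be_offloaded=["ffn_activation"],  # move to host memory instead of recomputing
    offload_src="device",
    offload_dst="pinned_host",
)


def ffn(x, w):
  y = jnp.dot(x, w)
  # Only activations tagged with checkpoint_name are visible to the policy above.
  return checkpoint_name(y, "ffn_activation")


remat_ffn = jax.checkpoint(ffn, policy=policy)
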

src/maxdiffusion/configuration_utils.py

Lines changed: 23 additions & 20 deletions
@@ -47,21 +47,24 @@
 
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
 
+
 class CustomEncoder(json.JSONEncoder):
-    """
-    Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
-    """
-    def default(self, o):
-        # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
-        if isinstance(o, type(jnp.dtype('bfloat16'))):
-            return str(o)
-        # Add fallbacks for other numpy types if needed
-        if isinstance(o, np.integer):
-            return int(o)
-        if isinstance(o, np.floating):
-            return float(o)
-        # Let the base class default method raise the TypeError for other types
-        return super().default(o)
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
+  """
+
+  def default(self, o):
+    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
+    if isinstance(o, type(jnp.dtype("bfloat16"))):
+      return str(o)
+    # Add fallbacks for other numpy types if needed
+    if isinstance(o, np.integer):
+      return int(o)
+    if isinstance(o, np.floating):
+      return float(o)
+    # Let the base class default method raise the TypeError for other types
+    return super().default(o)
+
 
 class FrozenDict(OrderedDict):
 
@@ -596,14 +599,14 @@ def to_json_saveable(value):
     config_dict.pop("quant", None)
     keys_to_remove = []
     for key, value in config_dict.items():
-        # Check the type of the value by its class name to avoid import issues
-        if type(value).__name__ == 'Rngs':
-            keys_to_remove.append(key)
+      # Check the type of the value by its class name to avoid import issues
+      if type(value).__name__ == "Rngs":
+        keys_to_remove.append(key)
 
     if keys_to_remove:
-        max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
-        for key in keys_to_remove:
-            config_dict.pop(key)
+      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
+      for key in keys_to_remove:
+        config_dict.pop(key)
 
     try:

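For reference, a minimal (made-up) example of serializing with the new encoder; each value below exercises one branch of default():

import json

import jax.numpy as jnp
import numpy as np

from maxdiffusion.configuration_utils import CustomEncoder

cfg = {
    "weights_dtype": jnp.dtype("bfloat16"),  # dtype branch -> serialized as the string "bfloat16"
    "train_steps": np.int64(1000),           # np.integer branch -> 1000
    "learning_rate": np.float32(1e-4),       # np.floating branch -> plain float
}
print(json.dumps(cfg, cls=CustomEncoder, indent=2))
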
src/maxdiffusion/generate_wan.py

Lines changed: 29 additions & 24 deletions
@@ -22,43 +22,47 @@
 from maxdiffusion.utils import export_to_video
 from google.cloud import storage
 
+
 def upload_video_to_gcs(output_dir: str, video_path: str):
-    """
-    Uploads a local video file to a specified Google Cloud Storage bucket.
-    """
-    try:
-        path_without_scheme = output_dir.removeprefix("gs://")
-        parts = path_without_scheme.split('/', 1)
-        bucket_name = parts[0]
-        folder_name = parts[1] if len(parts) > 1 else ''
+  """
+  Uploads a local video file to a specified Google Cloud Storage bucket.
+  """
+  try:
+    path_without_scheme = output_dir.removeprefix("gs://")
+    parts = path_without_scheme.split("/", 1)
+    bucket_name = parts[0]
+    folder_name = parts[1] if len(parts) > 1 else ""
 
-        storage_client = storage.Client()
-        bucket = storage_client.bucket(bucket_name)
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
 
-        source_file_path = f"./{video_path}"
-        destination_blob_name = os.path.join(folder_name, "videos", video_path)
+    source_file_path = f"./{video_path}"
+    destination_blob_name = os.path.join(folder_name, "videos", video_path)
 
-        blob = bucket.blob(destination_blob_name)
+    blob = bucket.blob(destination_blob_name)
 
-        max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
-        blob.upload_from_filename(source_file_path)
-        max_logging.log(f"Upload complete {source_file_path}.")
+    max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
+    blob.upload_from_filename(source_file_path)
+    max_logging.log(f"Upload complete {source_file_path}.")
+
+  except Exception as e:
+    max_logging.log(f"An error occurred: {e}")
 
-    except Exception as e:
-        max_logging.log(f"An error occurred: {e}")
 
 
 def delete_file(file_path: str):
   if os.path.exists(file_path):
-      try:
-          os.remove(file_path)
-          max_logging.log(f"Successfully deleted file: {file_path}")
-      except OSError as e:
-          max_logging.log(f"Error deleting file '{file_path}': {e}")
+    try:
+      os.remove(file_path)
+      max_logging.log(f"Successfully deleted file: {file_path}")
+    except OSError as e:
+      max_logging.log(f"Error deleting file '{file_path}': {e}")
   else:
-      max_logging.log(f"The file '{file_path}' does not exist.")
+    max_logging.log(f"The file '{file_path}' does not exist.")
+
 
 jax.config.update("jax_use_shardy_partitioner", True)
 
+
 def inference_generate_video(config, pipeline, filename_prefix=""):
   s0 = time.perf_counter()
   prompt = [config.prompt] * config.global_batch_size_to_train_on
@@ -88,6 +92,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
   delete_file(f"./{video_path}")
   return
 
+
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
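
The bucket/prefix parsing in upload_video_to_gcs is easiest to see on a concrete path; the bucket name, run folder, and file name below are made up:

import os

output_dir = "gs://my-bucket/wan/run-01"  # hypothetical destination passed to upload_video_to_gcs
video_path = "wan_output_0.mp4"           # hypothetical local file written by export_to_video

parts = output_dir.removeprefix("gs://").split("/", 1)
bucket_name = parts[0]                             # "my-bucket"
folder_name = parts[1] if len(parts) > 1 else ""   # "wan/run-01"
print(os.path.join(folder_name, "videos", video_path))  # wan/run-01/videos/wan_output_0.mp4
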

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 49 additions & 25 deletions
@@ -78,9 +78,18 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
+
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
 def _make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+    feature_description_fn,
+    prepare_sample_fn,
+    dataset_path,
+    is_training: bool,
 ):
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
@@ -93,10 +102,10 @@ def _make_tfrecord_iterator(
   # Determine whether to use the "cached" dataset, which requires externally
   # provided parsing functions, or the default one with its internal parsing logic.
   make_cached_tfrecord_iterator = (
-    config.cache_latents_text_encoder_outputs
-    and is_dataset_dir_valid
-    and "load_tfrecord_cached" in config.get_keys()
-    and config.load_tfrecord_cached
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
   )
 
   feature_description = {
@@ -121,42 +130,47 @@ def prepare_sample(features):
   if not is_training:
     num_eval_samples = 0
     for _ in ds:
-        num_eval_samples += 1
+      num_eval_samples += 1
 
     remainder = num_eval_samples % global_batch_size
     if remainder != 0:
-        num_to_pad = global_batch_size - remainder
-        # Create a dataset of padding samples from the beginning
-        padding_ds = ds.take(num_to_pad)
-        # Add the padding samples to the end
-        ds = ds.concatenate(padding_ds)
-        max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
 
   used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
   ds = (
-    ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-    .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-    .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
   )
   if is_training:
     ds = (
-      ds.shuffle(global_batch_size * 10)
-      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .repeat(-1)
-      .prefetch(AUTOTUNE)
+        ds.shuffle(global_batch_size * 10)
+        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
+        .repeat(-1)
+        .prefetch(AUTOTUNE)
     )
   # For Evaluation
   else:
-    ds = (
-      ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
-      .prefetch(AUTOTUNE)
-    )
+    ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)
 
   iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
   return iter
 
+
 def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+    feature_description,
+    prepare_sample_fn,
+    is_training,
 ):
   """Iterator for TFRecord format. For Laion dataset,
   check out preparation script
@@ -165,4 +179,14 @@ def make_tfrecord_iterator(
   # Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset.
   # TODO: refactor to support evaluation on all dataset format.
   dataset_path = config.train_data_dir if is_training else config.eval_data_dir
-  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
+  return _make_tfrecord_iterator(
+      config,
+      dataloading_host_index,
+      dataloading_host_count,
+      mesh,
+      global_batch_size,
+      feature_description,
+      prepare_sample_fn,
+      dataset_path,
+      is_training,
+  )
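
The new evaluation branch pads the dataset up to a multiple of global_batch_size (reusing samples from the front) so that batching with drop_remainder=False still yields equally sized batches on every host. A tiny self-contained illustration with made-up sizes:

import tensorflow as tf

global_batch_size = 4
ds = tf.data.Dataset.range(10)  # pretend the eval split has 10 samples

num_eval_samples = sum(1 for _ in ds)              # 10
remainder = num_eval_samples % global_batch_size   # 2
if remainder != 0:
  num_to_pad = global_batch_size - remainder       # 2
  ds = ds.concatenate(ds.take(num_to_pad))         # pad with the first 2 samples

print(sum(1 for _ in ds))  # 12, evenly divisible by global_batch_size
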

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def make_data_iterator(
         global_batch_size,
         feature_description,
         prepare_sample_fn,
-        is_training
+        is_training,
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
