|
| 1 | +""" |
| 2 | + Copyright 2025 Google LLC |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | +""" |
| 16 | + |
"""
Prepare TFRecord shards from preprocessed latents and text embeddings.

Expects ``config.train_data_dir`` to contain a ``metadata.csv`` listing video
names plus ``train/<video>.tensors.pth`` files holding precomputed latents and
prompt embeddings; writes sharded ``.tfrec`` files to ``config.tfrecords_dir``.
"""
| 21 | + |
| 22 | +import os |
| 23 | +import functools |
| 24 | +from absl import app |
| 25 | +from typing import Sequence, Union, List |
| 26 | +from datasets import load_dataset |
| 27 | +import csv |
| 28 | +import numpy as np |
| 29 | +import jax |
| 30 | +import jax.numpy as jnp |
| 31 | +from jax.sharding import Mesh |
| 32 | +from maxdiffusion import pyconfig, max_utils |
| 33 | +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline |
| 34 | +from maxdiffusion.video_processor import VideoProcessor |
| 35 | + |
| 36 | +import torch |
| 37 | +import tensorflow as tf |
| 38 | + |
| 39 | + |
def image_feature(value):
  """Return a bytes_list Feature holding `value` JPEG-encoded.

  Note: the docstring previously claimed this wraps a raw string/byte, but the
  code encodes an image tensor as JPEG before wrapping it.

  Args:
    value: an image tensor accepted by `tf.io.encode_jpeg` (uint8 HWC).

  Returns:
    A `tf.train.Feature` containing the JPEG bytes.
  """
  encoded = tf.io.encode_jpeg(value).numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded]))
| 43 | + |
| 44 | + |
def bytes_feature(value):
  """Wrap the raw bytes of a scalar string tensor in a tf.train.Feature."""
  raw = value.numpy()
  bytes_list = tf.train.BytesList(value=[raw])
  return tf.train.Feature(bytes_list=bytes_list)
| 48 | + |
| 49 | + |
def float_feature(value):
  """Wrap a single float / double in a tf.train.Feature."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)
| 53 | + |
| 54 | + |
def int64_feature(value):
  """Wrap a single bool / enum / int / uint in a tf.train.Feature."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
| 58 | + |
| 59 | + |
def float_feature_list(value):
  """Wrap a sequence of floats / doubles in a single tf.train.Feature."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)
| 63 | + |
| 64 | + |
def create_example(latent, hidden_states):
  """Serialize one (latent, text-embedding) pair into a tf.train.Example string.

  Both tensors are serialized with `tf.io.serialize_tensor` and stored as
  bytes features under the keys "latents" and "encoder_hidden_states".
  """
  serialized_features = {
      "latents": bytes_feature(tf.io.serialize_tensor(latent)),
      "encoder_hidden_states": bytes_feature(tf.io.serialize_tensor(hidden_states)),
  }
  example = tf.train.Example(features=tf.train.Features(feature=serialized_features))
  return example.SerializeToString()
| 74 | + |
def generate_dataset(config):
  """Convert preprocessed .pth tensors into sharded TFRecord files.

  Reads ``metadata.csv`` under ``config.train_data_dir`` (first column is the
  video name, header row is skipped), loads the matching
  ``train/<video>.tensors.pth`` file for each row, and writes the latents and
  prompt embeddings as serialized tf.train.Examples, rolling over to a new
  ``.tfrec`` shard every ``config.no_records_per_shard`` records.

  Args:
    config: pyconfig config with ``tfrecords_dir``, ``train_data_dir``,
      ``no_records_per_shard`` and ``weights_dtype`` attributes.
  """
  tfrecords_dir = config.tfrecords_dir
  # exist_ok avoids a race with the previous exists()-then-makedirs pattern.
  os.makedirs(tfrecords_dir, exist_ok=True)

  no_records_per_shard = config.no_records_per_shard
  tf_rec_num = 0
  global_record_count = 0
  shard_record_count = 0
  writer = tf.io.TFRecordWriter(
      tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
  )
  try:
    metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
    with open(metadata_path, "r", newline="") as file:
      csv_reader = csv.reader(file)
      # Skip the header row; the default of None keeps an empty file from
      # raising StopIteration.
      next(csv_reader, None)
      for row in csv_reader:
        video_name = row[0]
        pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth")
        loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
        prompt_embeds = loaded_state_dict["prompt_emb"]["context"]
        latent = loaded_state_dict["latents"]
        # Upcast to float32 on the torch side, then cast to the configured
        # training dtype as a jax array. Target latent format: (4, 16, 1, 64, 64).
        latent = jnp.array(latent.float().numpy(), dtype=config.weights_dtype)
        prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=config.weights_dtype)
        writer.write(create_example(latent, prompt_embeds))
        shard_record_count += 1
        global_record_count += 1

        if shard_record_count >= no_records_per_shard:
          # Shard is full: close it and start the next one.
          writer.close()
          tf_rec_num += 1
          writer = tf.io.TFRecordWriter(
              tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
          )
          shard_record_count = 0
  finally:
    # Bug fix: the original never closed the last writer, so the final
    # (possibly partial) shard was never flushed to disk.
    writer.close()
| 120 | + |
def run(config):
  """Build the preprocessed TFRecord dataset described by `config`."""
  generate_dataset(config)
| 123 | + |
| 124 | + |
def main(argv: Sequence[str]) -> None:
  """Script entry point: initialize pyconfig from argv, then run."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)
| 128 | + |
| 129 | + |
# Run through absl so flags are parsed before main() executes.
if __name__ == "__main__":
  app.run(main)