Skip to content

Commit 220f24b

Browse files
committed
Adds Pusa V1 video dataset preprocessing script.
1 parent 2b8549a commit 220f24b

1 file changed

Lines changed: 131 additions & 0 deletions

File tree

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
Prepare tfrecords with latents and text embeddings preprocessed.
19+
1. Download the dataset
20+
"""
21+
22+
import os
23+
import functools
24+
from absl import app
25+
from typing import Sequence, Union, List
26+
from datasets import load_dataset
27+
import csv
28+
import numpy as np
29+
import jax
30+
import jax.numpy as jnp
31+
from jax.sharding import Mesh
32+
from maxdiffusion import pyconfig, max_utils
33+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
34+
from maxdiffusion.video_processor import VideoProcessor
35+
36+
import torch
37+
import tensorflow as tf
38+
39+
40+
def image_feature(value):
  """Returns a bytes_list Feature containing *value* JPEG-encoded.

  Args:
    value: An image tensor acceptable to `tf.io.encode_jpeg` (presumably
      uint8 HWC — TODO confirm against callers).

  Returns:
    A `tf.train.Feature` wrapping the JPEG-encoded bytes.
  """
  # NOTE: the original docstring claimed "a bytes_list from a string / byte",
  # copied from bytes_feature; this function actually JPEG-encodes an image.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
43+
44+
45+
def bytes_feature(value):
  """Wrap an eager bytes tensor into a bytes_list tf.train.Feature."""
  raw_bytes = value.numpy()
  byte_list = tf.train.BytesList(value=[raw_bytes])
  return tf.train.Feature(bytes_list=byte_list)
48+
49+
50+
def float_feature(value):
  """Wrap a single float / double into a float_list tf.train.Feature."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)
53+
54+
55+
def int64_feature(value):
  """Wrap a single bool / enum / int / uint into an int64_list tf.train.Feature."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)
58+
59+
60+
def float_feature_list(value):
  """Wrap a sequence of floats / doubles into a float_list tf.train.Feature."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)
63+
64+
65+
def create_example(latent, hidden_states):
  """Serialize one (latent, text-embedding) pair into a tf.train.Example string.

  Both tensors are serialized with `tf.io.serialize_tensor` and stored under
  the keys "latents" and "encoder_hidden_states".
  """
  serialized_latent = tf.io.serialize_tensor(latent)
  serialized_hidden = tf.io.serialize_tensor(hidden_states)
  features = tf.train.Features(
      feature={
          "latents": bytes_feature(serialized_latent),
          "encoder_hidden_states": bytes_feature(serialized_hidden),
      }
  )
  return tf.train.Example(features=features).SerializeToString()
74+
75+
def _shard_path(tfrecords_dir, shard_index, last_record_index):
  """Build the output path for one tfrecord shard (naming kept from original)."""
  return os.path.join(tfrecords_dir, "file_%.2i-%i.tfrec" % (shard_index, last_record_index))


def generate_dataset(config):
  """Convert preprocessed .pth tensor files into sharded tfrecord files.

  Reads `metadata.csv` under `config.train_data_dir` (first row is treated as
  a header and skipped), loads each row's `train/<video_name>.tensors.pth`
  (expects keys "latents" and "prompt_emb"->"context"), casts both tensors to
  `config.weights_dtype`, and writes serialized (latent, prompt-embedding)
  pairs into shards of `config.no_records_per_shard` records each under
  `config.tfrecords_dir`.

  Args:
    config: pyconfig object providing `tfrecords_dir`, `train_data_dir`,
      `no_records_per_shard`, and `weights_dtype`.
  """
  tfrecords_dir = config.tfrecords_dir
  os.makedirs(tfrecords_dir, exist_ok=True)  # idempotent; replaces exists()+makedirs race

  no_records_per_shard = config.no_records_per_shard
  tf_rec_num = 0
  global_record_count = 0
  shard_record_count = 0
  writer = tf.io.TFRecordWriter(
      _shard_path(tfrecords_dir, tf_rec_num, global_record_count + no_records_per_shard)
  )

  metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
  try:
    with open(metadata_path, "r", newline="") as file:
      csv_reader = csv.reader(file)
      next(csv_reader)  # skip the header row
      for row in csv_reader:
        video_name = row[0]
        pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth")
        loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
        prompt_embeds = loaded_state_dict["prompt_emb"]["context"]
        latent = loaded_state_dict["latents"]
        # Cast through float32 numpy before converting to the configured dtype.
        latent = jnp.array(latent.float().numpy(), dtype=config.weights_dtype)
        prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=config.weights_dtype)
        writer.write(create_example(latent, prompt_embeds))
        shard_record_count += 1
        global_record_count += 1
        if shard_record_count >= no_records_per_shard:
          # Shard is full: close it and roll over to a new file.
          writer.close()
          tf_rec_num += 1
          writer = tf.io.TFRecordWriter(
              _shard_path(tfrecords_dir, tf_rec_num, global_record_count + no_records_per_shard)
          )
          shard_record_count = 0
  finally:
    # BUG FIX: the original never closed the final (possibly partial) shard,
    # which can lose buffered records when the process exits.
    writer.close()
120+
121+
def run(config):
  """Entry point used by main(): build the tfrecord dataset from *config*."""
  generate_dataset(config)
123+
124+
125+
def main(argv: Sequence[str]) -> None:
  """CLI entry: initialize pyconfig from argv, then generate the dataset."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)
128+
129+
130+
# Script entry point: delegate to absl's app.run so flags/argv are handled.
if __name__ == "__main__":
  app.run(main)

0 commit comments

Comments
 (0)