Commit c1fc2f5

fix hf unit test (#151)
1 parent ec4166e commit c1fc2f5

4 files changed: 36 additions & 7 deletions

File tree

requirements.txt
src/maxdiffusion/input_pipeline/_hf_data_processing.py
src/maxdiffusion/maxdiffusion_utils.py
src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 jax>=0.4.30
 jaxlib>=0.4.30
-grain-nightly
+grain-nightly==0.0.10
 google-cloud-storage==2.17.0
 absl-py
 datasets
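
The change above pins grain-nightly to an exact release instead of floating with each nightly build. A quick sanity check of the resolved version in a fresh environment (standard library only; the distribution name matches the requirements entry):

# Confirm the environment resolved the exact pinned version of grain-nightly.
from importlib.metadata import version

assert version("grain-nightly") == "0.0.10"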

src/maxdiffusion/input_pipeline/_hf_data_processing.py

Lines changed: 12 additions & 5 deletions

@@ -108,6 +108,7 @@ def __init__(
     self.current_shard = dataloading_host_index
     self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard)
     self.data_iter = None
+    self.out_of_data = False
 
   def _check_shard_count(self):
     if self.n_shards < self.dataloading_host_count:
@@ -119,11 +120,15 @@ def _check_shard_count(self):
       self.n_shards = self.dataloading_host_count
 
   def _update_shard(self):
-    new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards
-    max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}")
-    self.current_shard = new_shard
-    self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
-    self.data_iter = iter(self.dataset_shard)
+    new_shard = self.current_shard + self.dataloading_host_count
+    if new_shard < self.n_shards:
+      max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}")
+      self.current_shard = new_shard
+      self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
+      self.data_iter = iter(self.dataset_shard)
+    else:
+      max_logging.log(f"Run out of shards on host {self.dataloading_host_index}, shard {new_shard} is not available")
+      self.out_of_data = True
 
   def __len__(self):
     """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length,
@@ -138,6 +143,8 @@ def __getitem__(self, index):
 
     while True:
       try:
+        if self.out_of_data:
+          return None
         data = next(self.data_iter)
         return data
       except StopIteration:
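
The old `_update_shard` wrapped the shard index with `% self.n_shards`, so a host that exhausted its assigned shards would cycle back onto shards already being read. The new version only advances while an unseen shard exists, and otherwise raises the `out_of_data` flag, which `__getitem__` now checks so the loader returns `None` instead of re-reading data. A stand-alone sketch of the new advancement arithmetic, with hypothetical host and shard counts and no HuggingFace dependency:

# Simulate the new shard assignment: each host starts at its own index and
# jumps ahead by the host count, stopping once no unseen shard remains.
def shards_for_host(host_index: int, host_count: int, n_shards: int) -> list[int]:
  visited = []
  shard = host_index           # mirrors self.current_shard = dataloading_host_index
  while shard < n_shards:      # mirrors the new `if new_shard < self.n_shards`
    visited.append(shard)
    shard += host_count        # mirrors new_shard = current_shard + host_count
  return visited               # past this point the host sets out_of_data = True

# 10 shards across 4 hosts: each host reads only its own shards, then stops,
# instead of wrapping (the old modulo) onto shards another host already owns.
assert shards_for_host(0, 4, 10) == [0, 4, 8]
assert shards_for_host(2, 4, 10) == [2, 6]
assert shards_for_host(3, 4, 10) == [3, 7]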

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 21 additions & 1 deletion

@@ -14,6 +14,8 @@
 limitations under the License.
 """
 
+import io
+from PIL import Image
 import importlib
 import numpy as np
 import tensorflow as tf
@@ -72,6 +74,24 @@ def vae_apply(images, sample_rng, vae, vae_params):
   return latents
 
 
+def convert_dict_to_pil(image):
+  """
+  Converts a dictionary containing image bytes to a PIL Image object.
+
+  Args:
+    image_dict: A dictionary with keys 'bytes' (image data) and 'path' (optional).
+
+  Returns:
+    A PIL Image object.
+  """
+  if isinstance(image, dict):
+    image_bytes = image["bytes"]
+    image_stream = io.BytesIO(image_bytes)  # Create a BytesIO object
+    pil_image = Image.open(image_stream)  # Open the image from the stream
+    return pil_image
+  return image
+
+
 def transform_images(
     examples,
     image_column,
@@ -83,7 +103,7 @@ def transform_images(
 ):
   """Preprocess images to latents."""
   images = list(examples[image_column])
-  images = [np.asarray(image) for image in images]
+  images = [convert_dict_to_pil(image) for image in images]
   tensor_list = []
   for image in images:
     image = tf.image.resize(image, [image_resolution, image_resolution], method="bilinear", antialias=True)
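
HuggingFace image columns can arrive undecoded as `{'bytes': ..., 'path': ...}` dicts, which `np.asarray` would turn into a 0-d object array rather than pixel data; the new `convert_dict_to_pil` helper decodes such records first and passes already-decoded images through untouched. A round-trip check of the helper (the import path follows the file location above; the synthetic PNG record is an assumption standing in for an undecoded dataset row):

import io

import numpy as np
from PIL import Image

from maxdiffusion.maxdiffusion_utils import convert_dict_to_pil

# Fake an undecoded HF image feature: raw PNG bytes plus an optional path.
buf = io.BytesIO()
Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save(buf, format="PNG")
record = {"bytes": buf.getvalue(), "path": None}

pil = convert_dict_to_pil(record)        # dict in, PIL image out
assert pil.size == (8, 8)
assert convert_dict_to_pil(pil) is pil   # decoded images pass through unchanged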

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 0 deletions

@@ -148,6 +148,8 @@ def user_init(raw_keys):
 
   if "hf_train_files" in raw_keys and not raw_keys["hf_train_files"]:
     raw_keys["hf_train_files"] = None
+  if "hf_access_token" in raw_keys and not raw_keys["hf_access_token"]:
+    raw_keys["hf_access_token"] = None
 
   raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
 
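The pyconfig change mirrors the existing `hf_train_files` handling one line above it: an empty string in the config is normalized to `None`, so downstream code can test whether a token was provided with a simple truthiness check. A minimal sketch of that normalization on a hypothetical `raw_keys` dict:

raw_keys = {"hf_train_files": "", "hf_access_token": "", "per_device_batch_size": 1}

# Empty-string HF settings are treated as "not provided".
for key in ("hf_train_files", "hf_access_token"):
  if key in raw_keys and not raw_keys[key]:
    raw_keys[key] = None

assert raw_keys["hf_train_files"] is None
assert raw_keys["hf_access_token"] is None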