|
19 | 19 |
|
20 | 20 | import numpy as np |
21 | 21 |
|
22 | | -import tensorflow as tf |
23 | | - |
24 | 22 | import jax |
25 | 23 | import jax.numpy as jnp |
26 | 24 | from jax.sharding import PartitionSpec as P |
@@ -100,26 +98,19 @@ def reset(self): |
100 | 98 | @staticmethod |
101 | 99 | def get_place_holder_synthetic_data(config: pyconfig.HyperParameters): |
102 | 100 | """fill negative value in synthetic data""" |
103 | | - output = {} |
104 | | - output["inputs"] = tf.data.Dataset.from_tensor_slices( |
105 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
106 | | - ) |
107 | | - output["inputs_position"] = tf.data.Dataset.from_tensor_slices( |
108 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
109 | | - ) |
110 | | - output["inputs_segmentation"] = tf.data.Dataset.from_tensor_slices( |
111 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
112 | | - ) |
113 | | - output["targets"] = tf.data.Dataset.from_tensor_slices( |
114 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
115 | | - ) |
116 | | - output["targets_position"] = tf.data.Dataset.from_tensor_slices( |
117 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
118 | | - ) |
119 | | - output["targets_segmentation"] = tf.data.Dataset.from_tensor_slices( |
120 | | - np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32) |
121 | | - ) |
122 | | - dataset = tf.data.Dataset.zip((output)) # pytype: disable=wrong-arg-types |
123 | | - dataset = dataset.repeat() |
124 | | - dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count()) |
125 | | - return dataset |
| 101 | + batch_size = config.global_batch_size_to_load // jax.process_count() |
| 102 | + neg_ones = np.full((batch_size, config.max_target_length), -1, dtype=np.int32) |
| 103 | + batch = { |
| 104 | + "inputs": neg_ones, |
| 105 | + "inputs_position": neg_ones, |
| 106 | + "inputs_segmentation": neg_ones, |
| 107 | + "targets": neg_ones, |
| 108 | + "targets_position": neg_ones, |
| 109 | + "targets_segmentation": neg_ones, |
| 110 | + } |
| 111 | + |
| 112 | + def infinite_iterator(): |
| 113 | + while True: |
| 114 | + yield batch |
| 115 | + |
| 116 | + return infinite_iterator() |
0 commit comments