Writes a dataset to a TFRecord file.
tf.data.experimental.TFRecordWriter(
filename, compression_type=None
)
The elements of the dataset must be scalar strings. To serialize dataset
elements as strings, you can use the tf.io.serialize_tensor function.
dataset = tf.data.Dataset.range(3)
dataset = dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
writer.write(dataset)
To read back the elements, use TFRecordDataset.
dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
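The write and read snippets above can be combined into one round trip. The following is a minimal sketch, assuming eager execution (the TensorFlow 2 default) and a temporary file in place of "/path/to/file.tfrecord"; the names path, restored, and values are illustrative only.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical temporary path standing in for "/path/to/file.tfrecord".
path = os.path.join(tempfile.mkdtemp(), "example.tfrecord")

# Serialize each int64 element to a scalar string, then write the dataset.
dataset = tf.data.Dataset.range(3).map(tf.io.serialize_tensor)
tf.data.experimental.TFRecordWriter(path).write(dataset)

# Read the records back and parse each one into an int64 tensor.
restored = tf.data.TFRecordDataset(path).map(
    lambda x: tf.io.parse_tensor(x, tf.int64))
values = [int(v) for v in restored.as_numpy_iterator()]
```

The dtype passed to tf.io.parse_tensor must match the dtype that was serialized, which is tf.int64 for tf.data.Dataset.range.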
To shard a dataset across multiple TFRecord files:
dataset = ... # dataset to be written
def reduce_func(key, dataset):
  filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
  writer = tf.data.experimental.TFRecordWriter(filename)
  writer.write(dataset.map(lambda _, x: x))
  return tf.data.Dataset.from_tensors(filename)
dataset = dataset.enumerate()
dataset = dataset.apply(tf.data.experimental.group_by_window(
  lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
))
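The sharding recipe above leaves PATH_PREFIX, NUM_SHARDS, and the input dataset undefined, and it only builds a pipeline. The sketch below fills those in with hypothetical values (two shards, a temporary directory, ten serialized integers) and assumes eager execution. It also iterates over the resulting dataset, since the writes inside reduce_func run only when the pipeline is consumed, and then reads each shard back.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical values; the recipe above leaves both names undefined.
NUM_SHARDS = 2
PATH_PREFIX = os.path.join(tempfile.mkdtemp(), "shard-")

# Elements must be scalar strings, so serialize each int64 first.
dataset = tf.data.Dataset.range(10).map(tf.io.serialize_tensor)

def reduce_func(key, window):
  # Write every element that shares this shard key to one file.
  filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
  writer = tf.data.experimental.TFRecordWriter(filename)
  writer.write(window.map(lambda _, x: x))
  return tf.data.Dataset.from_tensors(filename)

dataset = dataset.enumerate()
dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max))

# The pipeline is lazy: iterating over it is what runs reduce_func
# and therefore performs the writes.
shard_files = sorted(f.decode() for f in dataset.as_numpy_iterator())

def read_shard(path):
  # Parse one shard back into a list of Python ints.
  records = tf.data.TFRecordDataset(path).map(
      lambda x: tf.io.parse_tensor(x, tf.int64))
  return [int(v) for v in records.as_numpy_iterator()]

shards = [read_shard(p) for p in shard_files]
```

With these assumed inputs, elements are assigned round-robin by their enumeration index, so each shard holds every NUM_SHARDS-th element.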
filename: a string path indicating where to write the TFRecord data.
compression_type: (Optional.) a string indicating what type of compression to use when writing the file. See tf.io.TFRecordCompressionType for what types of compression are available. Defaults to None.
write
write(
dataset
)
Writes a dataset to a TFRecord file.
An operation that writes the content of the specified dataset to the file specified in the constructor.
If the file exists, it will be overwritten.
dataset: a tf.data.Dataset whose elements are to be written to a file.
In graph mode, this returns an operation which when executed performs the write. In eager mode, the write is performed by the method itself and there is no return value.
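The overwrite behavior described for write can be checked by writing twice to the same path. This is a small sketch under the assumption of eager execution, where each write call performs the write immediately; the path and the helper read_back are hypothetical names introduced for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical temporary path used only for this illustration.
path = os.path.join(tempfile.mkdtemp(), "overwrite.tfrecord")
writer = tf.data.experimental.TFRecordWriter(path)

def read_back(p):
  # Parse every record in the file into a Python int.
  ds = tf.data.TFRecordDataset(p).map(
      lambda x: tf.io.parse_tensor(x, tf.int64))
  return [int(v) for v in ds.as_numpy_iterator()]

# In eager mode the write happens inside the call itself.
writer.write(tf.data.Dataset.range(3).map(tf.io.serialize_tensor))
first = read_back(path)

# A second write to the same path replaces the file instead of appending.
writer.write(tf.data.Dataset.range(10, 12).map(tf.io.serialize_tensor))
second = read_back(path)
```

After the second call, only the elements from the second dataset remain in the file.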
Raises
TypeError: if dataset is not a tf.data.Dataset.
TypeError: if the elements produced by the dataset are not scalar strings.
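Both error conditions can be illustrated with inputs that violate them. The sketch below assumes eager execution; raises_type_error is a hypothetical helper, and the path is a temporary placeholder.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical temporary path used only for this illustration.
path = os.path.join(tempfile.mkdtemp(), "invalid.tfrecord")
writer = tf.data.experimental.TFRecordWriter(path)

def raises_type_error(candidate):
  """Returns True if writer.write(candidate) raises TypeError."""
  try:
    writer.write(candidate)
  except TypeError:
    return True
  return False

# A plain Python list is not a tf.data.Dataset.
not_a_dataset = raises_type_error([b"a", b"b"])

# tf.data.Dataset.range yields int64 elements, not scalar strings.
not_strings = raises_type_error(tf.data.Dataset.range(3))

# A dataset of scalar strings is accepted and written.
valid = raises_type_error(tf.data.Dataset.from_tensor_slices([b"a", b"b"]))
```

Mapping tf.io.serialize_tensor over a non-string dataset before calling write is one way to avoid the second error.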