A transformation that batches ragged elements into tf.RaggedTensors.
tf.data.experimental.dense_to_ragged_batch(
batch_size, drop_remainder=False, row_splits_dtype=tf.dtypes.int64
)
This transformation combines multiple consecutive elements of the input dataset into a single element. Like tf.data.Dataset.batch, the components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not evenly divide the number of input elements N and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
Unlike tf.data.Dataset.batch, the input elements to be batched may have different shapes, and each batch will be encoded as a tf.RaggedTensor.
Example:
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset = dataset.apply(
...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
>>> for batch in dataset:
...   print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>
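As a minimal sketch of the drop_remainder behavior described above: with five variable-length rows and batch_size=2, setting drop_remainder=True should drop the leftover element rather than emit a smaller final batch.
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(5))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset = dataset.apply(
...     tf.data.experimental.dense_to_ragged_batch(
...         batch_size=2, drop_remainder=True))
>>> # The fifth row ([0, 1, 2, 3]) cannot fill a batch, so it is dropped.
>>> for batch in dataset:
...   print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>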
Args:
batch_size: A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
row_splits_dtype: The dtype that should be used for the row_splits of any new ragged tensors. Existing tf.RaggedTensor elements do not have their row_splits dtype changed.
Returns:
Dataset: A Dataset.
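Similarly, a minimal sketch of the row_splits_dtype argument: requesting tf.dtypes.int32 should produce batches whose row_splits use that dtype (checked here via the row_splits property of tf.RaggedTensor).
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(4))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset = dataset.apply(
...     tf.data.experimental.dense_to_ragged_batch(
...         batch_size=2, row_splits_dtype=tf.dtypes.int32))
>>> # Each batch's row partition should use the requested 32-bit dtype.
>>> for batch in dataset:
...   print(batch.row_splits.dtype.name)
int32
int32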