tf.ragged.segment_ids_to_row_splits


Generates the RaggedTensor row_splits corresponding to a segmentation.

tf.ragged.segment_ids_to_row_splits(
    segment_ids, num_segments=None, out_type=None, name=None
)

Returns an integer vector splits, where splits[0] = 0 and splits[i] = splits[i-1] + count(segment_ids == i-1). In other words, splits[i] is the number of elements that belong to segments 0 through i-1. For example:

>>> print(tf.ragged.segment_ids_to_row_splits([0, 0, 0, 2, 2, 3, 4, 4, 4]))
tf.Tensor([0 3 3 5 6 9], shape=(6,), dtype=int64)
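The recurrence can be checked without TensorFlow. The sketch below is a hypothetical pure-Python helper (row_splits_from_segment_ids is not part of TensorFlow). It counts how often each segment id occurs and accumulates those counts. Its default for num_segments (max(segment_ids) + 1, or 0 for empty input) and the requirement that every id lie in [0, num_segments) are assumptions of this sketch, not a description of TensorFlow's implementation.

```python
def row_splits_from_segment_ids(segment_ids, num_segments=None):
    """Hypothetical pure-Python sketch of the splits recurrence (not the TF code).

    Assumes every id is a non-negative integer smaller than num_segments;
    an out-of-range id raises IndexError here.
    """
    if num_segments is None:
        # Assumed default: one segment per id up to the largest id, or none if empty.
        num_segments = max(segment_ids) + 1 if segment_ids else 0
    counts = [0] * num_segments
    for s in segment_ids:
        counts[s] += 1  # counts[s] == count(segment_ids == s)
    splits = [0]
    for c in counts:
        # When appended at index i: splits[i] = splits[i-1] + count(segment_ids == i-1)
        splits.append(splits[-1] + c)
    return splits


print(row_splits_from_segment_ids([0, 0, 0, 2, 2, 3, 4, 4, 4]))  # [0, 3, 3, 5, 6, 9]
print(row_splits_from_segment_ids([0, 0, 1], num_segments=4))    # [0, 2, 3, 3, 3]
```

Segment ids that never occur (1 in the first call, 2 and 3 in the second) produce repeated split values, that is, empty rows, and the result always has num_segments + 1 entries.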

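Args:

segment_ids: A 1-D integer Tensor.
num_segments: A scalar integer indicating the number of segments. Defaults to max(segment_ids) + 1 (or zero if segment_ids is empty).
out_type: The dtype for the return value. Defaults to segment_ids.dtype, or tf.int64 if segment_ids does not have a dtype.
name: A name prefix for the returned tensor (optional).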
Returns:

A sorted 1-D integer Tensor, with shape=[num_segments + 1].
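The returned splits follow the RaggedTensor row_splits convention: row i holds values[splits[i]:splits[i+1]], which is how tf.RaggedTensor.from_row_splits(values, row_splits) interprets them. The following pure-Python sketch (split_rows is a hypothetical helper, and the values list is made-up illustration data) shows that slicing for the splits from the example above, assuming the values are ordered by segment id.

```python
def split_rows(values, row_splits):
    """Hypothetical helper: row i is values[row_splits[i]:row_splits[i + 1]]."""
    return [values[start:end] for start, end in zip(row_splits[:-1], row_splits[1:])]


values = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
splits = [0, 3, 3, 5, 6, 9]  # splits for segment_ids [0, 0, 0, 2, 2, 3, 4, 4, 4]
rows = split_rows(values, splits)
print(rows)  # [['a', 'b', 'c'], [], ['d', 'e'], ['f'], ['g', 'h', 'i']]
```

Because the splits have num_segments + 1 entries, the result always has exactly num_segments rows, including empty rows for ids that do not occur.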