tf.random.categorical

Draws samples from a categorical distribution.

tf.random.categorical(
    logits, num_samples, dtype=None, seed=None, name=None
)

Example:

# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5)

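Args:

logits: 2-D Tensor with shape [batch_size, num_classes]. Each slice [i, :] represents the unnormalized log-probabilities for all classes.
num_samples: 0-D. Number of independent samples to draw for each row slice.
dtype: The integer type of the output: int32 or int64. Defaults to int64.
seed: A Python integer. Used to create a random seed for the distribution. See tf.random.set_seed for behavior.
name: Optional name for the operation.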

Returns:

The drawn samples of shape [batch_size, num_samples].
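
The following is a minimal sketch of how the output shape and values relate to the inputs, assuming TensorFlow 2.x with eager execution. The probabilities, the sample count of 1000, and the names `freqs` and `expected` are illustrative choices, not part of the API. Because sampling is random, the empirical frequencies only approximate the softmax of the logits.

```python
import tensorflow as tf

# Two rows (batch_size = 2) of log-probabilities over 3 classes.
# Any unnormalized log-probabilities would work; these happen to be normalized.
logits = tf.math.log([[0.5, 0.25, 0.25],
                      [0.1, 0.1, 0.8]])

# Draw 1000 independent class indices per row.
samples = tf.random.categorical(logits, num_samples=1000, seed=0)

print(samples.shape)  # (2, 1000), i.e. [batch_size, num_samples]
print(samples.dtype)  # <dtype: 'int64'>, the default when dtype is None

# Empirical class frequencies per row: one-hot encode, then average over samples.
freqs = tf.reduce_mean(tf.one_hot(samples, depth=3), axis=1)

# The probabilities implied by the logits, for comparison.
expected = tf.nn.softmax(logits)

print(freqs.numpy())
print(expected.numpy())
```

Each entry of `samples` is a class index in the range [0, num_classes), so it can be used directly with `tf.gather` or `tf.one_hot`.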