chainer.functions.gumbel_softmax

chainer.functions.gumbel_softmax(log_pi, tau=0.1, axis=1)

Gumbel-Softmax sampling function.

This function draws samples $y_i$ from the Gumbel-Softmax distribution,

$$ y_i = \frac{\exp\left((g_i + \log\pi_i)/\tau\right)}{\sum_{j}\exp\left((g_j + \log\pi_j)/\tau\right)}, $$

where $\tau$ is a temperature parameter and the $g_i$ are i.i.d. samples drawn from the Gumbel distribution $\mathrm{Gumbel}(0, 1)$.
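
The formula above can be sketched in plain NumPy as follows. This is a minimal illustration of the math under stated assumptions, not Chainer's implementation: the function name gumbel_softmax_np and its rng argument are hypothetical, and the Gumbel(0, 1) noise comes from NumPy's Generator.gumbel. Subtracting the maximum before exponentiating is a standard stability step that leaves the softmax result unchanged.

```python
import numpy as np

def gumbel_softmax_np(log_pi, tau=0.1, axis=1, rng=None):
    """Hypothetical NumPy sketch: y = softmax((g + log_pi) / tau) along `axis`."""
    rng = np.random.default_rng() if rng is None else rng
    log_pi = np.asarray(log_pi, dtype=np.float64)
    # Draw independent g ~ Gumbel(0, 1), one per entry of log_pi.
    g = rng.gumbel(loc=0.0, scale=1.0, size=log_pi.shape)
    z = (g + log_pi) / tau
    # Shift by the max along `axis` for numerical stability; softmax is unchanged.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Two rows of log-probabilities over three categories; normalize along axis=1.
log_pi = np.log(np.array([[0.2, 0.3, 0.5],
                          [0.7, 0.2, 0.1]]))
y = gumbel_softmax_np(log_pi, tau=0.5, axis=1, rng=np.random.default_rng(0))
print(y.shape)        # (2, 3)
print(y.sum(axis=1))  # each row sums to 1 (up to rounding)
```

Each row of y is a point on the probability simplex, so the output keeps the shape of log_pi while the entries along axis form a relaxed one-hot vector.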

See Categorical Reparameterization with Gumbel-Softmax.
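
The cited paper relaxes the Gumbel-Max trick: $\arg\max_i (g_i + \log\pi_i)$ with $g_i \sim \mathrm{Gumbel}(0, 1)$ is an exact sample from the categorical distribution with probabilities $\pi$. Because softmax preserves ordering, the argmax of a Gumbel-Softmax sample $y$ is that same categorical draw for any $\tau > 0$. The short NumPy check below illustrates this empirically; the probabilities and sample count are arbitrary choices for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.2, 0.3, 0.5])  # arbitrary example probabilities
n = 100_000

# Gumbel-Max: argmax_i (g_i + log pi_i) is distributed as Categorical(pi).
g = rng.gumbel(size=(n, pi.size))
samples = np.argmax(g + np.log(pi), axis=1)

# Empirical frequency of each category across the n draws.
freq = np.bincount(samples, minlength=pi.size) / n
print(freq)  # close to [0.2, 0.3, 0.5]
```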

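Parameters

log_pi (Variable or N-dimensional array): Input variable representing the log-probabilities $\log\pi$. They need not be normalized, because adding the same constant to every entry along axis leaves $y$ unchanged.
tau (float): Temperature $\tau$. Defaults to 0.1.
axis (int): Axis along which the softmax normalization is taken. Defaults to 1.
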
Returns

Output variable holding the samples $y$, with the same shape as log_pi.

Return type

Variable
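
The temperature tau controls how sharp the samples are: as $\tau \to 0$ the samples approach one-hot vectors, and as $\tau$ grows they approach the uniform distribution. The NumPy sketch below, which assumes nothing about Chainer internals and uses a hypothetical helper softmax, reuses one fixed draw of Gumbel noise at two temperatures so that only tau changes. Lowering tau makes the largest entry at least as large, while the argmax stays the same.

```python
import numpy as np

def softmax(z, axis=1):
    # Numerically stable softmax along `axis` (hypothetical helper).
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(42)
log_pi = np.log(np.array([[0.2, 0.3, 0.5]]))
g = rng.gumbel(size=log_pi.shape)  # one fixed draw of Gumbel(0, 1) noise

# Same noise, two temperatures: isolates the effect of tau.
y_warm = softmax((g + log_pi) / 1.0, axis=1)
y_cold = softmax((g + log_pi) / 0.1, axis=1)  # 0.1 matches the default tau

print(y_warm.round(3))
print(y_cold.round(3))  # largest entry is at least as large as in y_warm
# Dividing by tau does not change the ordering, so the argmax is identical.
print(y_warm.argmax(axis=1) == y_cold.argmax(axis=1))
```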