tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False
)
Defined in tensorflow/contrib/gan/python/losses/python/losses_impl.py.
The gradient penalty for the Wasserstein discriminator loss.
See Improved Training of Wasserstein GANs (https://arxiv.org/abs/1704.00028) for more details.
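The penalty from the paper samples points on straight lines between real and generated examples, measures the norm of the discriminator's gradient at those points, and penalizes its deviation from a target norm. The NumPy sketch below illustrates that computation per example. It is not the TFGAN source: it assumes a hypothetical grad_fn that returns the discriminator's gradient with respect to its input (TFGAN obtains this by differentiating discriminator_fn with TensorFlow), draws one uniform interpolation coefficient per example, and assumes the penalty form (||grad|| / target - 1)^2, with epsilon added inside the square root and the one-sided variant implemented as max(0, .).

```python
import numpy as np

def gradient_penalty_sketch(real, fake, grad_fn, epsilon=1e-10, target=1.0,
                            one_sided=False, rng=None):
    """Per-example WGAN-GP penalties (illustrative NumPy sketch, not TFGAN).

    grad_fn is a hypothetical callable: grad_fn(x) must return dD/dx for a
    batch x, with the same shape as x. Batch dimension is axis 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    real = np.asarray(real, dtype=float)
    fake = np.asarray(fake, dtype=float)
    # One interpolation coefficient per example, broadcast over remaining axes.
    alpha_shape = (real.shape[0],) + (1,) * (real.ndim - 1)
    alpha = rng.uniform(size=alpha_shape)
    interpolates = real + alpha * (fake - real)
    grads = np.asarray(grad_fn(interpolates), dtype=float)
    # Squared L2 norm per example; epsilon is added before the square root
    # so the norm (and its gradient) stays finite when grads are zero.
    sq_norms = np.sum(grads ** 2, axis=tuple(range(1, grads.ndim)))
    slopes = np.sqrt(sq_norms + epsilon)
    penalties = slopes / target - 1.0
    if one_sided:
        # One-sided variant: only gradient norms above the target are penalized.
        penalties = np.maximum(0.0, penalties)
    return penalties ** 2
```

For example, a linear critic D(x) = x . w has gradient w everywhere, so with w = (3, 4) every per-example penalty is approximately (5 - 1)^2 = 16 for the default target of 1.0, and approximately 0 for target=5.0.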
Args:
  real_data: Real data.
  generated_data: Output of the generator.
  generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator.
  discriminator_fn: A discriminator function that conforms to the TFGAN API.
  discriminator_scope: If not None, reuse discriminators from this scope.
  epsilon: A small positive number added for numerical stability when computing the gradient norm.
  target: Optional Python number or Tensor indicating the target value of the gradient norm. Defaults to 1.0.
  one_sided: If True, the one-sided penalty proposed in https://arxiv.org/abs/1709.08894 is used, which penalizes only gradient norms that exceed the target. Defaults to False.
  weights: Optional Tensor whose rank is either 0 or the same as the rank of real_data and generated_data, and which must be broadcastable to them (i.e., every dimension must be either 1 or the same as the corresponding dimension).
  scope: The scope for the operations performed in computing the loss.
  loss_collection: Collection to which this loss will be added.
  reduction: A tf.losses.Reduction to apply to the loss.
  add_summaries: Whether or not to add summaries for the loss.
Returns:
  A loss Tensor. The shape depends on reduction.
Raises:
  ValueError: If the rank of the data Tensors is unknown.
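Because the returned shape depends on reduction, it can help to see how a few reduction modes turn per-example penalties into a loss. The NumPy sketch below mirrors the documented semantics of some tf.losses.Reduction modes as understood here (NONE keeps the weighted per-example values, SUM adds them, MEAN divides the sum by the sum of weights, SUM_BY_NONZERO_WEIGHTS divides it by the number of nonzero weights). The helper reduce_weighted and its string mode names are hypothetical, only a subset of modes is covered, and returning 0 when the divisor is 0 is an assumption of this sketch.

```python
import numpy as np

def reduce_weighted(losses, weights=1.0, reduction="SUM_BY_NONZERO_WEIGHTS"):
    """Hypothetical helper mimicking a subset of tf.losses.Reduction modes."""
    losses = np.asarray(losses, dtype=float)
    # Weights must broadcast to the shape of the per-example losses.
    w = np.broadcast_to(np.asarray(weights, dtype=float), losses.shape)
    weighted = losses * w
    if reduction == "NONE":
        return weighted  # same shape as the per-example losses
    total = weighted.sum()
    if reduction == "SUM":
        return total
    if reduction == "MEAN":
        denom = w.sum()
    elif reduction == "SUM_BY_NONZERO_WEIGHTS":
        denom = np.count_nonzero(w)
    else:
        raise ValueError(f"unsupported reduction in this sketch: {reduction}")
    # Assumed behavior: return 0 instead of dividing by zero.
    return total / denom if denom != 0 else 0.0
```

With per-example penalties [1, 2, 3, 4] and weights [2, 0, 1, 0], SUM_BY_NONZERO_WEIGHTS gives (2 + 3) / 2 = 2.5, while NONE returns a length-4 vector; every other mode shown yields a scalar.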