tf.contrib.gan.stargan_loss(
model,
generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(tfgan_losses_impl.wasserstein_generator_loss),
discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(tfgan_losses_impl.wasserstein_discriminator_loss),
gradient_penalty_weight=10.0,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
reconstruction_loss_fn=tf.losses.absolute_difference,
reconstruction_loss_weight=10.0,
classification_loss_fn=tf.losses.softmax_cross_entropy,
classification_loss_weight=1.0,
classification_one_hot=True,
add_summaries=True
)
Defined in tensorflow/contrib/gan/python/train.py.
StarGAN Loss.
The loss comprises four major parts, corresponding to the generator, discriminator, reconstruction, and classification loss functions below (overview: http://screen/tMRMBAohDYG).
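As a rough usage sketch (not from the official docs): the loss is computed on the StarGANModel returned by stargan_model(). The toy my_generator_fn and my_discriminator_fn below, and the assumed builder contracts (generator_fn(inputs, targets) returning transformed images; discriminator_fn(inputs, num_domains) returning a (source, domain-logits) pair), are illustrative assumptions that should be checked against the stargan_model() docstring.

```python
import tensorflow as tf

tfgan = tf.contrib.gan  # TF 1.x contrib API


def my_generator_fn(inputs, targets):
  # Toy generator (an assumption, not the paper's architecture): broadcast the
  # target-domain one-hot vector over the spatial dims, concatenate it with
  # the image, and project back to 3 channels with a 1x1 convolution.
  targets = targets[:, tf.newaxis, tf.newaxis, :] * tf.ones_like(inputs[:, :, :, :1])
  net = tf.concat([inputs, targets], axis=-1)
  return tf.layers.conv2d(net, filters=3, kernel_size=1)


def my_discriminator_fn(inputs, num_domains):
  # Toy discriminator (an assumption): global average pooling followed by two
  # heads, a real/fake score and per-domain classification logits.
  net = tf.reduce_mean(inputs, axis=[1, 2])
  source_prediction = tf.layers.dense(net, 1)
  domain_prediction = tf.layers.dense(net, num_domains)
  return source_prediction, domain_prediction


# Illustrative inputs: a batch of 8 images and their rank-2 one-hot domain labels.
images = tf.random_normal([8, 64, 64, 3])
domain_labels = tf.one_hot([0, 1, 2, 0, 1, 2, 0, 1], depth=3)

model = tfgan.stargan_model(
    generator_fn=my_generator_fn,
    discriminator_fn=my_discriminator_fn,
    input_data=images,
    input_data_domain_label=domain_labels)

# Combined StarGAN loss with the documented defaults.
loss = tfgan.stargan_loss(model)
```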
Args:
model: (StarGAN) Model output of the stargan_model() function call.
generator_loss_fn: The loss function on the generator. Takes a StarGANModel named tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a StarGANModel named tuple.
gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10 per the original paper (https://arxiv.org/abs/1711.09020). Set to 0 or None to turn off the gradient penalty.
gradient_penalty_epsilon: (float) A small positive number added for numerical stability when computing the gradient norm.
gradient_penalty_target: (float, or tf.float Tensor) The target value of the gradient norm. Defaults to 1.0.
gradient_penalty_one_sided: (bool) If True, the penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to False.
reconstruction_loss_fn: The reconstruction loss function. Defaults to the L1 norm; the function must conform to the tf.losses API.
reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
classification_loss_fn: The loss function on the discriminator's ability to classify the domain of the input. Defaults to one-hot softmax cross-entropy loss; the function must conform to the tf.losses API.
classification_loss_weight: (float) Classification loss weight. Defaults to 1.0.
classification_one_hot: (bool) Whether the label is a one-hot representation. Defaults to True. If False, classification_loss_fn needs to be a sigmoid cross-entropy loss instead (see the sketch after this list).
add_summaries: (bool) Whether to add the loss to the summary.
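For illustration only, a hedged sketch of overriding some of these defaults on the model built in the sketch above; the specific weight values are illustrative, not recommendations.

```python
# Hypothetical override of the documented defaults on the sketch's `model`.
loss = tfgan.stargan_loss(
    model,
    gradient_penalty_weight=None,      # 0 or None turns off the gradient penalty
    reconstruction_loss_weight=5.0,    # illustrative re-weighting
    classification_loss_weight=2.0)

# If input_data_domain_label were a multi-label (not one-hot) rank-2 tensor,
# the docstring asks for a sigmoid cross-entropy classification loss instead:
# loss = tfgan.stargan_loss(
#     model,
#     classification_loss_fn=tf.losses.sigmoid_cross_entropy,
#     classification_one_hot=False)
```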
Returns:
A GANLoss namedtuple holding the generator loss and the discriminator loss.
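A minimal sketch of consuming the returned tuple by hand; the 'Generator' and 'Discriminator' scope names below assume stargan_model()'s default scopes, and the optimizer and learning rate are illustrative.

```python
# GANLoss exposes the two scalar losses.
gen_loss = loss.generator_loss
dis_loss = loss.discriminator_loss

# Hypothetical manual training ops, collecting variables by scope name.
gen_vars = tf.trainable_variables(scope='Generator')
dis_vars = tf.trainable_variables(scope='Discriminator')
train_gen = tf.train.AdamOptimizer(1e-4).minimize(gen_loss, var_list=gen_vars)
train_dis = tf.train.AdamOptimizer(1e-4).minimize(dis_loss, var_list=dis_vars)
```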
Raises:
ValueError: If the input StarGANModel.input_data_domain_label does not have rank 2, or its second dimension is not defined.
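To satisfy that shape requirement, the domain label passed to stargan_model() should be rank 2 with a statically known second dimension, i.e. [batch_size, num_domains]; a small sketch with illustrative sizes:

```python
# Integer domain ids for a batch of 4 examples drawn from 5 possible domains.
domain_ids = tf.constant([0, 3, 1, 4])

# tf.one_hot produces a rank-2 tensor of static shape [4, 5], so the second
# dimension is defined and the check described above passes.
input_data_domain_label = tf.one_hot(domain_ids, depth=5)
```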