tf.contrib.gan.stargan_loss(
    model,
    generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
        tfgan_losses_impl.wasserstein_generator_loss),
    discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
        tfgan_losses_impl.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=tf.losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=tf.losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True
)
Defined in tensorflow/contrib/gan/python/train.py.
StarGAN Loss.
The four major parts of the loss can be found here: http://screen/tMRMBAohDYG.
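Judging from the arguments below, the loss combines an adversarial term, a gradient penalty, a domain classification term, and a reconstruction term. The following is a minimal sketch of one plausible weighting, assuming the scheme of the StarGAN paper (https://arxiv.org/abs/1711.09020): the gradient penalty and the classification loss on real inputs go into the discriminator loss, while the classification loss on generated outputs and the reconstruction loss go into the generator loss. The helper combine_stargan_losses and the local GANLoss stand-in are hypothetical, and the inputs are plain floats rather than tensors.

```python
from collections import namedtuple

# Hypothetical stand-in for the GANLoss namedtuple returned by stargan_loss.
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])


def combine_stargan_losses(
    adv_generator,       # adversarial loss for the generator
    adv_discriminator,   # adversarial loss for the discriminator
    gradient_penalty,    # unweighted penalty term for the discriminator
    cls_real,            # domain classification loss on real inputs
    cls_fake,            # domain classification loss on generated outputs
    reconstruction,      # L1 distance between inputs and reconstructions
    gradient_penalty_weight=10.0,
    classification_loss_weight=1.0,
    reconstruction_loss_weight=10.0,
):
    """Sketch of the StarGAN-paper weighting (an assumption, not TFGAN's code)."""
    d_loss = adv_discriminator + classification_loss_weight * cls_real
    # A weight of 0 or None turns the gradient penalty off.
    if gradient_penalty_weight:
        d_loss += gradient_penalty_weight * gradient_penalty
    g_loss = (
        adv_generator
        + classification_loss_weight * cls_fake
        + reconstruction_loss_weight * reconstruction
    )
    return GANLoss(generator_loss=g_loss, discriminator_loss=d_loss)


losses = combine_stargan_losses(
    adv_generator=0.5, adv_discriminator=-0.25, gradient_penalty=0.125,
    cls_real=0.75, cls_fake=1.0, reconstruction=0.25,
)
print(losses)  # → GANLoss(generator_loss=4.0, discriminator_loss=1.75)
```

Passing gradient_penalty_weight=None (or 0) drops the penalty term, mirroring the documented way to turn it off.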
Args:
model: (StarGANModel) Model output of the stargan_model() function call.
generator_loss_fn: The loss function on the generator. Takes a StarGANModel namedtuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a StarGANModel namedtuple.
gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10, following the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to turn off the gradient penalty.
gradient_penalty_epsilon: (float) A small positive number added for numerical stability when computing the gradient norm. Defaults to 1e-10.
gradient_penalty_target: (float, or tf.float Tensor) The target value of the gradient norm. Defaults to 1.0.
gradient_penalty_one_sided: (bool) If True, the one-sided penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to False.
reconstruction_loss_fn: The reconstruction loss function. Defaults to the L1 norm (tf.losses.absolute_difference); the function must conform to the tf.losses API.
reconstruction_loss_weight: (float) Reconstruction loss weight. Defaults to 10.0.
classification_loss_fn: The loss function on the discriminator's ability to classify the domain of the input. Defaults to one-hot softmax cross-entropy loss (tf.losses.softmax_cross_entropy); the function must conform to the tf.losses API.
classification_loss_weight: (float) Classification loss weight. Defaults to 1.0.
classification_one_hot: (bool) Whether the labels use a one-hot representation. Defaults to True. If False, classification_loss_fn needs to be a sigmoid cross-entropy loss instead.
add_summaries: (bool) Whether to add the losses to the summaries. Defaults to True.
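The three gradient_penalty_* arguments can be illustrated with a plain-Python sketch. It assumes the common WGAN-GP form, in which each example contributes (sqrt(||g||^2 + epsilon) / target - 1)^2 and the results are averaged over the batch; the one-sided variant is assumed to clamp negative deviations to zero so that only norms above the target are penalized. gradient_penalty_term is a hypothetical helper: it takes precomputed gradient vectors, whereas the real loss computes discriminator gradients on tensors.

```python
import math


def gradient_penalty_term(gradients, epsilon=1e-10, target=1.0, one_sided=False):
    """WGAN-GP-style penalty for a batch of gradient vectors (sketch).

    `gradients` holds one gradient vector per example, taken with respect to
    the discriminator input (in WGAN-GP, at points interpolated between real
    and generated samples).
    """
    penalties = []
    for grad in gradients:
        # epsilon is added under the sqrt for numerical stability; in an
        # autodiff framework it keeps the gradient of sqrt finite at zero.
        norm = math.sqrt(sum(g * g for g in grad) + epsilon)
        deviation = norm / target - 1.0
        if one_sided:
            # One-sided variant: only norms above the target are penalized.
            deviation = max(0.0, deviation)
        penalties.append(deviation ** 2)
    # Average over the batch; the caller multiplies by gradient_penalty_weight.
    return sum(penalties) / len(penalties)


grads = [[3.0, 4.0], [0.3, 0.4]]  # gradient norms of about 5.0 and 0.5
print(gradient_penalty_term(grads))                  # two-sided
print(gradient_penalty_term(grads, one_sided=True))  # one-sided
```

With the same gradients, the one-sided result is smaller because the example whose norm is below the target no longer contributes.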
Returns:
A GANLoss namedtuple containing the generator loss and the discriminator loss.
Raises:
ValueError: If StarGANModel.input_data_domain_label does not have rank 2, or if its second dimension is not defined.
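The conditions under Raises can be sketched as a check on a static shape, represented here as a plain tuple with None for an unknown dimension. check_domain_label_shape is a hypothetical helper for illustration, not TFGAN's code; in TensorFlow the shape would come from the tensor's static TensorShape. Treating the second dimension as the number of domains is an assumption.

```python
def check_domain_label_shape(shape):
    """Validate a static shape tuple (None marks an unknown dimension).

    Mirrors the documented ValueError conditions for
    StarGANModel.input_data_domain_label (hypothetical helper).
    """
    if shape is None or len(shape) != 2:
        raise ValueError(
            "input_data_domain_label must have rank 2, got shape %r" % (shape,))
    # The second dimension is assumed to be the number of domains.
    if shape[1] is None:
        raise ValueError(
            "Dimension 2 of input_data_domain_label must be defined, "
            "got shape %r" % (shape,))
    return shape[1]


print(check_domain_label_shape((None, 5)))  # the batch dimension may be unknown
```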