tf.contrib.gan.estimator.gan_head

Aliases:

  • tf.contrib.gan.estimator.gan_head
  • tf.contrib.gan.estimator.head.gan_head

tf.contrib.gan.estimator.gan_head(
    generator_loss_fn,
    discriminator_loss_fn,
    generator_optimizer,
    discriminator_optimizer,
    use_loss_summaries=True,
    get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
    get_eval_metric_ops_fn=None,
    name=None
)
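The default get_hooks_fn, tfgan_train.get_sequential_train_hooks(), is a factory: calling it returns the function that the head later calls with a GANTrainOps tuple. The pure-Python sketch below models that pattern under stated assumptions. GANTrainOps, GANTrainSteps, and RunOpsHook are simplified, hypothetical stand-ins (not the TFGAN classes), a "hook" is reduced to an (op, steps) pair, and simulate_step is a hypothetical helper. As we understand the TFGAN implementation, the generator hook is listed before the discriminator hook, and both run before the op that increments the global step.

```python
from collections import namedtuple

# Hypothetical stand-ins for TFGAN's GANTrainOps and GANTrainSteps tuples;
# only the fields used below are modeled.
GANTrainOps = namedtuple(
    "GANTrainOps",
    ["generator_train_op", "discriminator_train_op", "global_step_inc_op"])
GANTrainSteps = namedtuple(
    "GANTrainSteps", ["generator_train_steps", "discriminator_train_steps"])

# Hypothetical stand-in for a session hook that runs `op` `steps` times
# before each training step.
RunOpsHook = namedtuple("RunOpsHook", ["op", "steps"])


def sequential_hooks_fn(train_steps=GANTrainSteps(1, 1)):
    """Returns a get_hooks_fn: a function mapping GANTrainOps to hooks."""
    def get_hooks(train_ops):
        # Generator updates are scheduled first, then discriminator updates.
        return [
            RunOpsHook(train_ops.generator_train_op,
                       train_steps.generator_train_steps),
            RunOpsHook(train_ops.discriminator_train_op,
                       train_steps.discriminator_train_steps),
        ]
    return get_hooks


def simulate_step(hooks, train_ops):
    """Lists, in order, the ops this simplified model runs per global step."""
    executed = []
    for hook in hooks:
        executed.extend([hook.op] * hook.steps)
    # The main training op (here, the global-step increment) runs last.
    executed.append(train_ops.global_step_inc_op)
    return executed
```

For example, with GANTrainSteps(1, 2) one global step in this model runs the generator op once, the discriminator op twice, and then the global-step increment. Passing a different factory result as get_hooks_fn is how this schedule would be changed.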

Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.

Creates a GANHead. (deprecated)

Args:

  • generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar loss.
  • discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
  • generator_optimizer: The optimizer for generator updates.
  • discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
  • use_loss_summaries: If True, adds loss summaries. If False, does not. If None, uses the default behavior.
  • get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks.
  • get_eval_metric_ops_fn: A function that takes a GANModel and returns a dict of metric results keyed by name. The output of this function is passed to tf.estimator.EstimatorSpec during evaluation.
  • name: Name of the head. If provided, summary and metric keys are suffixed with "/" + name.
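The loss-function contract above (a GANModel in, a scalar out) can be illustrated with the Wasserstein losses. The pure-Python sketch below is a simplified model, not the TFGAN code: GANModel is a hypothetical stand-in that keeps only the two discriminator output fields, represented as lists of floats, and _mean is a hypothetical helper. TFGAN's own tf.contrib.gan.losses.wasserstein_generator_loss and wasserstein_discriminator_loss operate on tensors and accept additional arguments such as weights.

```python
from collections import namedtuple

# Hypothetical stand-in for TFGAN's GANModel: only the discriminator outputs
# read by the losses below are modeled, as plain lists of floats.
GANModel = namedtuple(
    "GANModel", ["discriminator_real_outputs", "discriminator_gen_outputs"])


def _mean(values):
    """Arithmetic mean of a non-empty list of floats."""
    return sum(values) / len(values)


def wasserstein_generator_loss(model):
    # The generator tries to raise the critic's score on generated samples.
    return -_mean(model.discriminator_gen_outputs)


def wasserstein_discriminator_loss(model):
    # The critic tries to score real samples above generated ones.
    return (_mean(model.discriminator_gen_outputs)
            - _mean(model.discriminator_real_outputs))
```

Each function has the shape gan_head expects for generator_loss_fn and discriminator_loss_fn respectively: it reads what it needs from the model and reduces it to a single number.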

Returns:

An instance of GANHead.
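When the name argument is set, the returned head suffixes its summary and metric keys with "/" + name. The helper below is a hypothetical sketch of that documented rule applied to a dict of metric results; it is not part of the TFGAN API, and the key names in the usage example are illustrative only.

```python
def suffix_keys(metrics, name=None):
    """Hypothetical helper: appends "/" + name to every key when name is given."""
    if not name:
        # No head name provided: keys are left unchanged.
        return dict(metrics)
    return {key + "/" + name: value for key, value in metrics.items()}
```

For instance, suffix_keys({"generator_loss": 0.7}, name="gan") yields {"generator_loss/gan": 0.7}, which keeps keys from several named heads distinct.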