tf.contrib.gan.estimator.GANHead

Class GANHead

Aliases:

  • Class tf.contrib.gan.estimator.GANHead
  • Class tf.contrib.gan.estimator.head.GANHead

Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.

Head for a GAN.

__init__

__init__(
    generator_loss_fn,
    discriminator_loss_fn,
    generator_optimizer,
    discriminator_optimizer,
    use_loss_summaries=True,
    get_hooks_fn=None,
    get_eval_metric_ops_fn=None,
    name=None
)

Head for GAN training. (deprecated)

Args:

  • generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar.
  • discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
  • generator_optimizer: The optimizer for generator updates.
  • discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
  • use_loss_summaries: If True, add loss summaries. If False, do not add them. If None, use the default behavior.
  • get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks. If None, defaults to train.get_sequential_train_hooks().
  • get_eval_metric_ops_fn: A function that takes a GANModel, and returns a dict of metric results keyed by name. The output of this function is passed into tf.estimator.EstimatorSpec during evaluation.
  • name: Name of the head. If provided, summary and metric keys will be suffixed by "/" + name.
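
As a rough illustration of how these arguments fit together, the sketch below builds a head from TFGAN's Wasserstein losses and two Adam optimizers. It assumes TensorFlow 1.x with tf.contrib available. The helper names build_gan_head and eval_metric_ops_fn, the learning rates, and the metric keys are illustrative choices, not part of the API.

```python
def eval_metric_ops_fn(gan_model):
    """Hypothetical metrics callback: takes a GANModel, returns metrics keyed by name."""
    import tensorflow as tf  # deferred import; tf.metrics here is the TensorFlow 1.x API

    # Each value is the (value_op, update_op) pair returned by tf.metrics.mean.
    return {
        "mean_real_output": tf.metrics.mean(gan_model.discriminator_real_outputs),
        "mean_generated_output": tf.metrics.mean(gan_model.discriminator_gen_outputs),
    }


def build_gan_head(name=None):
    """Hypothetical helper that constructs a GANHead (needs TensorFlow 1.x)."""
    import tensorflow as tf  # deferred import; tf.contrib exists only in TensorFlow 1.x

    tfgan = tf.contrib.gan
    return tfgan.estimator.GANHead(
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        # Learning rate and beta1 are arbitrary example values.
        generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
        use_loss_summaries=True,
        # Spelled out for clarity; passing None is documented to select the same default.
        get_hooks_fn=tfgan.get_sequential_train_hooks(),
        get_eval_metric_ops_fn=eval_metric_ops_fn,
        name=name,
    )
```

Calling build_gan_head(name="gan") would, per the description of name above, suffix summary and metric keys with "/gan".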

Properties

logits_dimension

Size of the last dimension of the logits Tensor.

Typically, logits is of shape [batch_size, logits_dimension].

Returns:

The expected size of the logits tensor.

name

The name of this head.

Returns:

A string.

Methods

tf.contrib.gan.estimator.GANHead.create_estimator_spec

create_estimator_spec(
    features,
    mode,
    logits,
    labels=None,
    train_op_fn=tf.contrib.gan.gan_train_ops
)

Returns EstimatorSpec that a model_fn can return.

See Head for more details.

Args:

  • features: Must be None.
  • mode: Estimator's ModeKeys.
  • logits: A GANModel tuple.
  • labels: Must be None.
  • train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, and discriminator optimizer, and returns a GANTrainOps tuple. For example, this function can come from TFGAN's train.py library, or can be custom.

Returns:

EstimatorSpec.

Raises:

  • ValueError: If features isn't None.
  • ValueError: If train_op_fn isn't provided in train mode.
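
The train_op_fn contract described above (GANModel, GANLoss, and two optimizers in, GANTrainOps out) can be met by a thin wrapper around an existing function. The sketch below is one possible way to do that. make_train_op_fn is a hypothetical helper, and the keyword argument in the usage comment is an assumption about tf.contrib.gan.gan_train_ops in TensorFlow 1.x.

```python
def make_train_op_fn(base_train_ops_fn, **extra_kwargs):
    """Hypothetical helper: bind extra keyword arguments to a train-ops function.

    base_train_ops_fn is any callable with the four-argument contract described
    above, for example tf.contrib.gan.gan_train_ops.
    """
    def train_op_fn(gan_model, gan_loss, generator_optimizer, discriminator_optimizer):
        # Forward the four required arguments plus the bound extras.
        return base_train_ops_fn(
            gan_model,
            gan_loss,
            generator_optimizer,
            discriminator_optimizer,
            **extra_kwargs
        )

    return train_op_fn


# Possible use with TensorFlow 1.x (not executed here):
#   train_op_fn = make_train_op_fn(
#       tf.contrib.gan.gan_train_ops, check_for_unused_update_ops=False)
#   spec = head.create_estimator_spec(
#       features=None, mode=mode, logits=gan_model, train_op_fn=train_op_fn)
```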

tf.contrib.gan.estimator.GANHead.create_loss

create_loss(
    features,
    mode,
    logits,
    labels
)

Returns a GANLoss tuple from the provided GANModel.

See Head for more details.

Args:

  • features: Input dict of Tensor objects. Unused.
  • mode: Estimator's ModeKeys.
  • logits: A GANModel tuple.
  • labels: Must be None.

Returns:

A GANLoss tuple.
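
To make the input/output relationship concrete, here is a pure-Python sketch of the idea: apply the two loss functions to the model and pack the results into a tuple. It is not the library's implementation. The GANLoss stand-in, its field names, the create_loss_sketch helper, and the choice to raise on non-None labels are all assumptions made for illustration.

```python
from collections import namedtuple

# Stand-in for TFGAN's GANLoss tuple; the field names are an assumption of this sketch.
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])


def create_loss_sketch(gan_model, generator_loss_fn, discriminator_loss_fn, labels=None):
    """Hypothetical helper mirroring the documented contract of create_loss."""
    if labels is not None:
        # The docs only say labels must be None; raising here is this sketch's choice.
        raise ValueError("labels must be None")
    return GANLoss(
        generator_loss=generator_loss_fn(gan_model),
        discriminator_loss=discriminator_loss_fn(gan_model),
    )


# Toy usage with plain numbers in place of tensors and a dict in place of a GANModel.
toy_model = {"real_score": 1.0, "generated_score": 0.25}
toy_loss = create_loss_sketch(
    toy_model,
    generator_loss_fn=lambda m: -m["generated_score"],
    discriminator_loss_fn=lambda m: m["generated_score"] - m["real_score"],
)
print(toy_loss)  # GANLoss(generator_loss=-0.25, discriminator_loss=-0.75)
```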