Class GANHead
Aliases:
- Class tf.contrib.gan.estimator.GANHead
- Class tf.contrib.gan.estimator.head.GANHead
Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.
Head for a GAN.
__init__
__init__(
generator_loss_fn,
discriminator_loss_fn,
generator_optimizer,
discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=None,
get_eval_metric_ops_fn=None,
name=None
)
Head for GAN training. (deprecated)
Args:
- generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar.
- discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
- generator_optimizer: The optimizer for generator updates.
- discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
- use_loss_summaries: If True, add loss summaries. If False, do not add them. If None, use the defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks. Defaults to train.get_sequential_train_hooks().
- get_eval_metric_ops_fn: A function that takes a GANModel and returns a dict of metric results keyed by name. The output of this function is passed into tf.estimator.EstimatorSpec during evaluation.
- name: Name of the head. If provided, summary and metrics keys are suffixed by "/" + name.
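The get_eval_metric_ops_fn and name arguments can be illustrated with a plain-Python sketch. GANModelStub, the body of get_eval_metric_ops_fn, and suffix_keys are hypothetical: the real GANModel holds Tensors, and a real metric function would typically return tf.metrics results rather than floats. suffix_keys only models the documented rule that metric keys gain a "/" + name suffix when name is set.

```python
from collections import namedtuple

# Hypothetical stand-in for TFGAN's GANModel: only two fields, plain lists
# instead of Tensors.
GANModelStub = namedtuple("GANModelStub", ["real_data", "generated_data"])


def get_eval_metric_ops_fn(gan_model):
    """Takes a GANModel and returns a dict of metric results keyed by name."""
    real, fake = gan_model.real_data, gan_model.generated_data
    return {
        "real_mean": sum(real) / len(real),
        "generated_mean": sum(fake) / len(fake),
    }


def suffix_keys(metrics, name=None):
    """Models the documented naming rule: append "/" + name when name is set."""
    if name is None:
        return dict(metrics)
    return {key + "/" + name: value for key, value in metrics.items()}


model = GANModelStub(real_data=[1.0, 2.0, 3.0], generated_data=[0.0, 1.0])
metrics = get_eval_metric_ops_fn(model)
print(suffix_keys(metrics, name="mnist"))
# {'real_mean/mnist': 2.0, 'generated_mean/mnist': 0.5}
```

With name=None the keys are left unchanged, which matches the "if provided" wording above.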
Properties
logits_dimension
Size of the last dimension of the logits Tensor.
Typically, logits is of shape [batch_size, logits_dimension].
Returns:
The expected size of the logits tensor.
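The property is the size of the trailing axis. A minimal plain-Python sketch, using a hypothetical logits_shape helper and nested lists in place of a Tensor, makes the [batch_size, logits_dimension] convention concrete. This describes the generic Head convention; for GANHead, the logits argument passed to its methods is a GANModel tuple rather than a tensor.

```python
def logits_shape(logits):
    """Return (batch_size, logits_dimension) for a rank-2 nested list."""
    row_lengths = {len(row) for row in logits}
    if len(row_lengths) != 1:
        raise ValueError("logits must be a non-empty [batch_size, logits_dimension] array")
    (logits_dimension,) = row_lengths
    return len(logits), logits_dimension


# A batch of 2 examples with 3 logits each.
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
print(logits_shape(logits))  # (2, 3)
```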
name
The name of this head.
Returns:
A string.
Methods
tf.contrib.gan.estimator.GANHead.create_estimator_spec
create_estimator_spec(
features,
mode,
logits,
labels=None,
train_op_fn=tf.contrib.gan.gan_train_ops
)
Returns an EstimatorSpec that a model_fn can return.
See Head for more details.
Args:
- features: Must be None.
- mode: Estimator's ModeKeys.
- logits: A GANModel tuple.
- labels: Must be None.
- train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, and discriminator optimizer, and returns a GANTrainOps tuple. For example, this function can come from TFGAN's train.py library, or can be custom.
Returns:
An EstimatorSpec.
Raises:
- ValueError: If features isn't None.
- ValueError: If train_op_fn isn't provided in train mode.
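The argument rules above can be summarized in a plain-Python sketch of the train/eval control flow, including a custom train_op_fn. GANHeadSketch, the stub tuples, and the toy loss functions are all hypothetical. The GANTrainOps field names follow TFGAN's tuple, but summing the two losses into one reported scalar and using global_step_inc_op as the spec's train_op are assumptions about the real implementation, not documented behavior. PREDICT mode is not modeled.

```python
from collections import namedtuple

# Hypothetical stand-ins for TFGAN/Estimator types (plain Python, no Tensors).
GANModel = namedtuple("GANModel", ["real_data", "generated_data"])
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])
GANTrainOps = namedtuple(
    "GANTrainOps",
    ["generator_train_op", "discriminator_train_op", "global_step_inc_op"])
Spec = namedtuple("Spec", ["mode", "loss", "train_op", "training_hooks"])

TRAIN, EVAL = "train", "eval"  # string values of tf.estimator.ModeKeys


class GANHeadSketch:
    """Plain-Python model of the documented contract; not the real GANHead."""

    def __init__(self, generator_loss_fn, discriminator_loss_fn,
                 generator_optimizer, discriminator_optimizer, get_hooks_fn):
        self._gen_loss_fn = generator_loss_fn
        self._dis_loss_fn = discriminator_loss_fn
        self._gen_opt = generator_optimizer
        self._dis_opt = discriminator_optimizer
        self._get_hooks_fn = get_hooks_fn

    def create_estimator_spec(self, features, mode, logits, labels=None,
                              train_op_fn=None):
        # labels must be None per the docs; it is unused in this sketch.
        if features is not None:
            raise ValueError("`features` must be None")
        gan_model = logits  # "logits" is a GANModel tuple for this head
        gan_loss = GANLoss(self._gen_loss_fn(gan_model),
                           self._dis_loss_fn(gan_model))
        # Assumption: both losses are summed into one reported scalar.
        total_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
        if mode == EVAL:
            return Spec(mode, total_loss, None, [])
        if mode != TRAIN:
            raise NotImplementedError("only train and eval are sketched")
        if train_op_fn is None:
            raise ValueError("`train_op_fn` must be provided in train mode")
        train_ops = train_op_fn(gan_model, gan_loss, self._gen_opt, self._dis_opt)
        # Assumption: the step-increment op serves as the spec's train_op.
        return Spec(mode, total_loss, train_ops.global_step_inc_op,
                    self._get_hooks_fn(train_ops))


def custom_train_op_fn(gan_model, gan_loss, gen_opt, dis_opt):
    """A custom train_op_fn: (GANModel, GANLoss, opt, opt) -> GANTrainOps."""
    return GANTrainOps(gen_opt + ":minimize", dis_opt + ":minimize", "inc_step")


head = GANHeadSketch(
    generator_loss_fn=lambda m: -sum(m.generated_data),  # toy loss
    discriminator_loss_fn=lambda m: sum(m.generated_data) - sum(m.real_data),  # toy loss
    generator_optimizer="adam_g",
    discriminator_optimizer="adam_d",
    # Stand-in for get_sequential_train_hooks(): one "hook" per train op.
    get_hooks_fn=lambda ops: [ops.generator_train_op, ops.discriminator_train_op])

model = GANModel(real_data=[1.0, 2.0], generated_data=[0.5])
spec = head.create_estimator_spec(None, TRAIN, model, train_op_fn=custom_train_op_fn)
print(spec.loss, spec.train_op, spec.training_hooks)
# -3.0 inc_step ['adam_g:minimize', 'adam_d:minimize']
```

The two ValueError cases match the Raises list: a non-None features argument, or a missing train_op_fn in train mode.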
tf.contrib.gan.estimator.GANHead.create_loss
create_loss(
features,
mode,
logits,
labels
)
Returns a GANLoss tuple from the provided GANModel.
See Head for more details.
Args:
- features: Input dict of Tensor objects. Unused.
- mode: Estimator's ModeKeys.
- logits: A GANModel tuple.
- labels: Must be None.
Returns:
A GANLoss tuple.
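In other words, create_loss applies the head's two loss functions to the same GANModel and packs the resulting scalars into a GANLoss. Below is a plain-Python sketch of that idea under stated assumptions: the stand-in tuples are hypothetical, a free function replaces the method (the loss functions are passed in rather than read from the head), and the toy loss functions are not TFGAN losses.

```python
from collections import namedtuple

# Hypothetical stand-ins for TFGAN's GANModel and GANLoss tuples.
GANModel = namedtuple("GANModel", ["real_data", "generated_data"])
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])


def create_loss_sketch(generator_loss_fn, discriminator_loss_fn,
                       features, mode, logits, labels):
    """Apply each loss function to the GANModel and pack the results."""
    del features, mode, labels  # unused here (labels must be None per the docs)
    gan_model = logits  # for this head, "logits" is a GANModel tuple
    return GANLoss(generator_loss=generator_loss_fn(gan_model),
                   discriminator_loss=discriminator_loss_fn(gan_model))


def mean(xs):
    return sum(xs) / len(xs)


# Toy loss functions: each takes a GANModel and returns a scalar.
def gen_loss(m):
    return mean(m.real_data) - mean(m.generated_data)


def dis_loss(m):
    return -gen_loss(m)


model = GANModel(real_data=[2.0, 4.0], generated_data=[1.0, 1.0])
print(create_loss_sketch(gen_loss, dis_loss, None, "train", model, None))
# GANLoss(generator_loss=2.0, discriminator_loss=-2.0)
```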