tf.estimator.MultiHead

Creates a Head for multi-objective learning.

Inherits From: Head

tf.estimator.MultiHead(
    heads, head_weights=None
)

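This class merges the output of multiple Head objects. Specifically:

* For training, it sums the losses of each head (weighted by head_weights, if given) and calls train_op_fn with this final loss.
* For eval, it merges metrics by adding a head.name suffix to the keys in eval metrics, such as precision/head1.name and precision/head2.name.
* For prediction, it merges predictions and updates the keys in the prediction dict to a 2-tuple, (head.name, prediction_key). It merges export_outputs so that, by default, the first head is served.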

Usage:

import tensorflow as tf

# In `input_fn`, specify labels as a dict keyed by head name:
def input_fn():
  features = ...
  labels1 = ...
  labels2 = ...
  # Keys must match the `name` of each head ('head1' and 'head2' below).
  return features, {'head1': labels1, 'head2': labels2}

# In `model_fn`, specify logits as a dict keyed by head name:
def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create MultiHead from two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for each head, and combine them into a dict.
  # `logit_fn` is a placeholder for your own network code.
  logits1, logits2 = logit_fn()
  logits = {head1.name: logits1, head2.name: logits2}
  # Return the merged EstimatorSpec.
  return head.create_estimator_spec(..., logits=logits, ...)

# Create an estimator with this model_fn.
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)
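Both dicts are keyed by the heads' name values. With name='head1' and name='head2', the keys are 'head1' and 'head2'. The helper below is a minimal, library-independent sketch of checking that a labels or logits dict has exactly one entry per head name. check_keyed_by_head_names is a hypothetical name and is not part of TensorFlow.

```python
def check_keyed_by_head_names(head_names, mapping, what="logits"):
    """Hypothetical helper (not part of TF): return the mapping's values in
    head order, or raise if its keys do not match the head names exactly."""
    missing = [name for name in head_names if name not in mapping]
    unexpected = [key for key in mapping if key not in head_names]
    if missing or unexpected:
        raise ValueError(
            "{} keys must match head names {}: missing={}, unexpected={}".format(
                what, list(head_names), missing, unexpected))
    # Values come back in the order the heads were given.
    return [mapping[name] for name in head_names]


head_names = ["head1", "head2"]  # the `name` of each head, in order
labels = {"head1": [0, 2], "head2": [1, 0]}
per_head_labels = check_keyed_by_head_names(head_names, labels, what="labels")
# per_head_labels == [[0, 2], [1, 0]]
```

A key that is not a head name, such as 'head3', raises a ValueError instead of silently being ignored.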

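MultiHead also accepts logits as a single Tensor of shape [D0, D1, ... DN, logits_dimension], where logits_dimension is the sum of the heads' logits dimensions. It splits the Tensor along the last dimension and distributes the slices among the heads in the order the heads were given. For example: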

import numpy as np

# Input logits.
logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],
                  dtype=np.float32)
# Suppose head1.logits_dimension = 2 and head2.logits_dimension = 3. After
# splitting, the result is:
logits_dict = {'head1_name': [[-1., 1.], [-1.5, 1.]],
               'head2_name': [[2., -2., 2.], [-3., 2., -2.]]}
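The split in the example above can be reproduced with NumPy. This is a minimal sketch that assumes each head's logits_dimension is known and listed in head order. It mirrors the behavior described above but is not the library's implementation, and split_logits is a hypothetical name.

```python
import numpy as np


def split_logits(logits, head_dims):
    """Split `logits` along its last axis into one slice per head.

    `head_dims` lists each head's logits_dimension in head order; their sum
    must equal logits.shape[-1].
    """
    if sum(head_dims) != logits.shape[-1]:
        raise ValueError("Sum of head dims {} != last logits dim {}".format(
            sum(head_dims), logits.shape[-1]))
    # np.split wants cut points: the running totals, minus the final one.
    cut_points = np.cumsum(head_dims)[:-1]
    return np.split(logits, cut_points, axis=-1)


logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],
                  dtype=np.float32)
head1_logits, head2_logits = split_logits(logits, [2, 3])
```

The slices keep the leading dimensions [D0, ... DN] intact, so the same function works for logits of any rank.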

Usage:

import tensorflow as tf

def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create a MultiHead from the two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for the MultiHead as a single `Tensor` whose last dimension
  # is head.logits_dimension (3 + 1 = 4 here). `logit_fn` is a placeholder
  # for your own network code.
  logits = logit_fn(logits_dimension=head.logits_dimension)
  # Return the merged EstimatorSpec.
  return head.create_estimator_spec(..., logits=logits, ...)
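For the heads above, MultiClassHead(n_classes=3) contributes 3 logits and BinaryClassHead contributes 1, so the single logits Tensor has a last dimension of 4. The sketch below computes which columns of that Tensor belong to each head. It assumes the heads are listed in the order passed to MultiHead, and head_column_ranges is a hypothetical helper.

```python
def head_column_ranges(head_dims):
    """Map each head to its [start, stop) columns in the last logits axis.

    `head_dims` is a list of (head_name, logits_dimension) pairs in the order
    the heads were passed to MultiHead. Returns (ranges, total_dimension).
    """
    ranges, start = {}, 0
    for name, dim in head_dims:
        ranges[name] = (start, start + dim)
        start += dim
    return ranges, start


# MultiClassHead(n_classes=3) has 3 logits; BinaryClassHead has 1.
ranges, total = head_column_ranges([("head1", 3), ("head2", 1)])
# ranges == {"head1": (0, 3), "head2": (3, 4)}, total == 4
```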

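Args:

heads: List or tuple of Head instances. All heads must have name specified. The first head in the list is the default used at serving time.
head_weights: Optional list of weights, the same length as heads. Used when merging losses to calculate the weighted sum of losses from each head. If None, all losses are weighted equally.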

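The constraints on the constructor arguments can be sketched as a plain validation function. It assumes these rules: heads is non-empty, every head has a name, and head_weights, when given, has one weight per head. StubHead and validate_multihead_args are hypothetical, and the error messages are not TensorFlow's.

```python
from collections import namedtuple

# Hypothetical stand-in for a Head object; only `name` matters here.
StubHead = namedtuple("StubHead", ["name"])


def validate_multihead_args(heads, head_weights=None):
    """Check MultiHead-style constructor arguments (a sketch, not TF code).

    Returns the effective per-head weights (1.0 each when unweighted).
    """
    if not heads:
        raise ValueError("Must specify at least one head.")
    for head in heads:
        if head.name is None:
            raise ValueError("All heads must have a name; got {}".format(head))
    if head_weights is None:
        return [1.0] * len(heads)
    if len(head_weights) != len(heads):
        raise ValueError("Got {} head_weights for {} heads.".format(
            len(head_weights), len(heads)))
    return list(head_weights)


heads = [StubHead("head1"), StubHead("head2")]
equal_weights = validate_multihead_args(heads)               # [1.0, 1.0]
custom_weights = validate_multihead_args(heads, [0.7, 0.3])  # [0.7, 0.3]
```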
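Attributes:

logits_dimension: See base_head.Head for details. For MultiHead, this is the sum of the logits dimensions of all heads.
loss_reduction: See base_head.Head for details.
name: See base_head.Head for details.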

Methods

create_estimator_spec

create_estimator_spec(
    features, mode, logits, labels=None, optimizer=None, trainable_variables=None,
    train_op_fn=None, update_ops=None, regularization_losses=None
)

Returns a model_fn.EstimatorSpec.

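Args:

features: Input dict mapping string feature names to Tensor or SparseTensor objects.
mode: Estimator's ModeKeys.
logits: Input Tensor of shape [D0, D1, ... DN, logits_dimension], or a dict mapping each head name to that head's logits Tensor.
labels: Dict mapping each head name to that head's labels. Required when mode equals TRAIN or EVAL.
optimizer: A tf.keras.optimizers.Optimizer instance used to optimize the loss in TRAIN mode.
trainable_variables: A list or tuple of Variable objects to update to minimize the loss. In TensorFlow 2.x these must be passed explicitly when using optimizer.
train_op_fn: Function that takes a scalar loss Tensor and returns train_op. Used if optimizer is None.
update_ops: A list or tuple of update ops to be run at training time, such as the mean and variance updates created by BatchNormalization layers.
regularization_losses: A list of additional scalar losses, such as regularization losses, added to the merged training loss of all heads.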

Returns:

A model_fn.EstimatorSpec instance.

Raises:

ValueError: If both train_op_fn and optimizer are None in TRAIN mode, or if both are set.

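In TRAIN mode, exactly one of optimizer and train_op_fn should be set. The following is a minimal sketch of that rule under the assumption that any non-None argument counts as "set". select_train_op_source is a hypothetical helper, and its messages are not TensorFlow's.

```python
def select_train_op_source(optimizer=None, train_op_fn=None):
    """Return which argument would build the train op in TRAIN mode."""
    if optimizer is not None and train_op_fn is not None:
        raise ValueError("train_op_fn and optimizer cannot both be set.")
    if optimizer is not None:
        return "optimizer"
    if train_op_fn is not None:
        return "train_op_fn"
    raise ValueError("train_op_fn and optimizer cannot both be None.")


# A train_op_fn receives the merged scalar loss and returns a train op;
# the identity lambda here is only a placeholder.
source = select_train_op_source(train_op_fn=lambda loss: loss)  # "train_op_fn"
```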
loss

loss(
    labels, logits, features=None, mode=None, regularization_losses=None
)

Returns regularized training loss. See base_head.Head for details.
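The merged loss is a weighted sum of the per-head training losses (a plain sum when head_weights is None) plus the sum of any regularization_losses. The sketch below shows only that arithmetic and assumes each per-head loss is already a scalar Python float. merge_head_losses is a hypothetical name.

```python
def merge_head_losses(head_losses, head_weights=None, regularization_losses=None):
    """Combine scalar per-head training losses into one regularized loss."""
    if head_weights is None:
        # Unweighted: every head counts equally.
        head_weights = [1.0] * len(head_losses)
    if len(head_weights) != len(head_losses):
        raise ValueError("Need exactly one weight per head loss.")
    merged = sum(w * loss for w, loss in zip(head_weights, head_losses))
    if regularization_losses:
        merged += sum(regularization_losses)
    return merged


plain = merge_head_losses([0.5, 2.0])                   # 2.5
weighted = merge_head_losses([0.5, 2.0], [1.0, 0.25])   # 1.0
regularized = merge_head_losses([0.5, 2.0], [1.0, 0.25], [0.1])  # about 1.1
```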

metrics

metrics(
    regularization_losses=None
)

Creates metrics. See base_head.Head for details.
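Eval metrics from different heads are kept apart by suffixing each metric key with the head name, as in precision/head1. The sketch below applies that naming convention to plain dicts of already computed values. It assumes metrics arrive grouped by head name, and merge_eval_metrics is a hypothetical helper rather than a TensorFlow function.

```python
def merge_eval_metrics(metrics_by_head):
    """Flatten {head_name: {metric_key: value}} into {"metric_key/head_name": value}."""
    merged = {}
    for head_name, metrics in metrics_by_head.items():
        for key, value in metrics.items():
            merged["{}/{}".format(key, head_name)] = value
    return merged


merged_metrics = merge_eval_metrics({
    "head1": {"accuracy": 0.9},
    "head2": {"accuracy": 0.8, "auc": 0.75},
})
# {"accuracy/head1": 0.9, "accuracy/head2": 0.8, "auc/head2": 0.75}
```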

predictions

predictions(
    logits, keys=None
)

Creates predictions. See base_head.Head for details.
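Merged predictions are keyed by 2-tuples of (head.name, prediction_key). The sketch below re-keys per-head prediction dicts that way. It assumes predictions arrive grouped by head name, and merge_predictions is a hypothetical helper. The keys 'class_ids' and 'logistic' are used only as example prediction keys.

```python
def merge_predictions(predictions_by_head):
    """Re-key {head_name: {prediction_key: value}} as {(head_name, prediction_key): value}."""
    return {
        (head_name, key): value
        for head_name, predictions in predictions_by_head.items()
        for key, value in predictions.items()
    }


merged_predictions = merge_predictions({
    "head1": {"class_ids": [2]},
    "head2": {"logistic": [0.9]},
})
# {("head1", "class_ids"): [2], ("head2", "logistic"): [0.9]}
```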

update_metrics

update_metrics(
    eval_metrics, features, logits, labels, regularization_losses=None
)

Updates eval metrics. See base_head.Head for details.