tf.estimator.BinaryClassHead

Creates a Head for single-label binary classification.

Inherits From: Head

tf.estimator.BinaryClassHead(
    weight_column=None, thresholds=None, label_vocabulary=None,
    loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, loss_fn=None,
    name=None
)

Uses the sigmoid_cross_entropy_with_logits loss.
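The per-element loss can be sketched in plain Python. This is a minimal illustration, not TensorFlow's implementation: it works on flat lists of floats rather than Tensors, and the function name only mirrors tf.nn.sigmoid_cross_entropy_with_logits. It uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)) for logit x and label z.

```python
import math


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy on flat lists of floats.

    Equivalent to -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)),
    rewritten so that large |x| cannot overflow exp().
    """
    if len(labels) != len(logits):
        raise ValueError("labels and logits must have the same length")
    return [
        max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
        for z, x in zip(labels, logits)
    ]


losses = sigmoid_cross_entropy_with_logits([1.0, 0.0, 1.0], [2.0, -1.0, 0.0])
```

A logit of 0 gives a loss of log(2) regardless of the label, since the predicted probability is exactly 0.5.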

The head expects logits with shape [D0, D1, ... DN, 1]. In many applications, the shape is [batch_size, 1].

labels must be a dense Tensor with shape matching logits, namely [D0, D1, ... DN, 1]. If label_vocabulary is given, labels must be a string Tensor with values from the vocabulary. If label_vocabulary is not given, labels must be a float Tensor with values in the interval [0, 1].
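A sketch of how labels might be turned into the float labels the loss consumes, assuming a two-entry vocabulary whose first entry maps to 0.0 and second to 1.0. The helper name labels_to_float is hypothetical, and its checks are this sketch's own rather than the exact validation TensorFlow performs.

```python
def labels_to_float(labels, label_vocabulary=None):
    """Convert raw labels (flat list) into float labels in [0, 1]."""
    if label_vocabulary is not None:
        # Assumption for the binary case: exactly two vocabulary entries.
        if len(label_vocabulary) != 2:
            raise ValueError("a binary head needs exactly two vocabulary entries")
        # Each string label becomes its index in the vocabulary.
        index = {value: float(i) for i, value in enumerate(label_vocabulary)}
        try:
            return [index[label] for label in labels]
        except KeyError as err:
            raise ValueError(f"label {err.args[0]!r} is not in the vocabulary") from None
    # Without a vocabulary, labels must already be numbers in [0, 1].
    out = [float(label) for label in labels]
    if any(not 0.0 <= v <= 1.0 for v in out):
        raise ValueError("float labels must lie in [0, 1]")
    return out
```

For example, labels_to_float(["no", "yes"], ["no", "yes"]) yields [0.0, 1.0].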

If weight_column is specified, weights must be of shape [D0, D1, ... DN], or [D0, D1, ... DN, 1].

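The loss is the weighted sum over the input dimensions, reduced according to loss_reduction. Namely, if the input labels have shape [batch_size, 1], the loss is the weighted sum over batch_size; with the default SUM_OVER_BATCH_SIZE reduction, that sum is then divided by the number of elements.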
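The weighting and reduction can be sketched as follows, assuming nested Python lists of shape [N, 1] in place of Tensors and plain strings in place of the ReductionV2 enum. reduce_weighted_loss is a hypothetical helper, and only the SUM and SUM_OVER_BATCH_SIZE reductions are modeled.

```python
def reduce_weighted_loss(unweighted_loss, weights=None, reduction="SUM_OVER_BATCH_SIZE"):
    """Weight per-example losses of shape [N, 1] and reduce them to a scalar.

    weights may have shape [N] or [N, 1]; None means every weight is 1.
    """
    losses = [row[0] for row in unweighted_loss]  # flatten [N, 1] -> [N]
    if weights is None:
        w = [1.0] * len(losses)
    else:
        # Accept both [N] and [N, 1] weight layouts.
        w = [x[0] if isinstance(x, (list, tuple)) else x for x in weights]
    if len(w) != len(losses):
        raise ValueError("weights must have shape [N] or [N, 1]")
    total = sum(loss * wi for loss, wi in zip(losses, w))
    if reduction == "SUM":
        return total
    if reduction == "SUM_OVER_BATCH_SIZE":
        # Divide by the number of loss elements, not by the sum of weights.
        return total / len(losses)
    raise ValueError(f"unsupported reduction: {reduction}")
```

Note that a zero weight still counts toward the divisor under SUM_OVER_BATCH_SIZE in this sketch.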

The head also supports a custom loss_fn. loss_fn takes (labels, logits) or (labels, logits, features, loss_reduction) as arguments and returns a loss with shape [D0, D1, ... DN, 1]. loss_fn must support float labels with shape [D0, D1, ... DN, 1], because the head applies label_vocabulary to the input labels before passing them to loss_fn.
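The loss_fn contract can be illustrated with a hypothetical dispatcher that passes features and loss_reduction only when the function declares them. call_loss_fn, my_loss, my_loss_with_features, and the "scale" feature are all invented for this sketch; a real loss_fn would operate on Tensors, and how TensorFlow decides which arguments to pass is an assumption here, not something this page states.

```python
import inspect
import math


def call_loss_fn(loss_fn, labels, logits, features=None,
                 loss_reduction="SUM_OVER_BATCH_SIZE"):
    """Hypothetical: call loss_fn with only the arguments it declares."""
    params = inspect.signature(loss_fn).parameters
    kwargs = {}
    if "features" in params:
        kwargs["features"] = features
    if "loss_reduction" in params:
        kwargs["loss_reduction"] = loss_reduction
    loss = loss_fn(labels=labels, logits=logits, **kwargs)
    # The returned loss must keep the [N, 1] shape of the labels.
    if len(loss) != len(labels) or any(len(row) != 1 for row in loss):
        raise ValueError("loss_fn must return shape [N, 1]")
    return loss


def my_loss(labels, logits):
    # Per-element sigmoid cross-entropy on [N, 1] nested lists of floats.
    return [[max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))]
            for (z,), (x,) in zip(labels, logits)]


def my_loss_with_features(labels, logits, features, loss_reduction):
    # Scales each loss by a hypothetical per-example "scale" feature.
    # loss_reduction is accepted to match the documented signature but unused.
    base = my_loss(labels, logits)
    return [[loss * s] for (loss,), s in zip(base, features["scale"])]
```

The labels reaching either function are already floats, consistent with the head applying label_vocabulary first.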

Args:

Attributes:

Methods

create_estimator_spec

create_estimator_spec(
    features, mode, logits, labels=None, optimizer=None, trainable_variables=None,
    train_op_fn=None, update_ops=None, regularization_losses=None
)

Returns an EstimatorSpec that a model_fn can return.

It is recommended to pass all arguments by name (as keyword arguments).

Args:

Returns:

An EstimatorSpec.

loss

loss(
    labels, logits, features=None, mode=None, regularization_losses=None
)

Returns regularized training loss. See base_head.Head for details.
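Conceptually, the regularized loss is the reduced training loss plus the sum of any regularization terms. A minimal sketch with Python floats standing in for scalar Tensors (the helper name is hypothetical):

```python
def regularized_training_loss(training_loss, regularization_losses=None):
    """Add regularization terms (for example L2 penalties) to the reduced loss."""
    if not regularization_losses:
        # No regularization: the training loss is returned unchanged.
        return training_loss
    return training_loss + sum(regularization_losses)
```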

metrics

metrics(
    regularization_losses=None
)

Creates metrics. See base_head.Head for details.
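When thresholds are passed to the constructor, the evaluation metrics include threshold-dependent values. The sketch below computes accuracy, precision, and recall at each threshold from labels in {0, 1} and raw logits. It is an illustration under assumptions: the function name is hypothetical, a prediction counts as positive when its probability exceeds the threshold, and the exact metric names and edge-case handling in TensorFlow may differ.

```python
import math


def _sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def threshold_metrics(labels, logits, thresholds):
    """Return {threshold: {"accuracy", "precision", "recall"}}."""
    probs = [_sigmoid(x) for x in logits]
    results = {}
    for t in thresholds:
        preds = [p > t for p in probs]
        pairs = list(zip(preds, labels))
        tp = sum(1 for p, y in pairs if p and y == 1.0)
        fp = sum(1 for p, y in pairs if p and y == 0.0)
        fn = sum(1 for p, y in pairs if not p and y == 1.0)
        correct = sum(1 for p, y in pairs if p == (y == 1.0))
        results[t] = {
            "accuracy": correct / len(labels),
            # Report 0.0 when the denominator is empty.
            "precision": tp / (tp + fp) if tp + fp else 0.0,
            "recall": tp / (tp + fn) if tp + fn else 0.0,
        }
    return results
```

Raising the threshold trades recall for precision; at a threshold above every predicted probability, nothing is predicted positive.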

predictions

predictions(
    logits, keys=None
)

Returns predictions based on keys. See base_head.Head for details.

Args:

Returns:

A dict of predictions.
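A sketch of the kind of dict predictions might return. The key names (logits, logistic, probabilities, class_ids, classes) are assumed to match TensorFlow's prediction keys, the sketch operates on a flat list of logits instead of a [D0, ..., DN, 1] Tensor, and classes without a vocabulary are rendered here as the strings "0" and "1". The function name is hypothetical.

```python
import math


def binary_predictions(logits, label_vocabulary=None, keys=None):
    """Build a predictions dict from a flat list of logits."""
    def sigmoid(x):
        # Numerically stable logistic function.
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        e = math.exp(x)
        return e / (1.0 + e)

    logistic = [sigmoid(x) for x in logits]
    # Assumption: a positive logit means class 1; a tie at 0 maps to class 0,
    # as an argmax over [0, logit] would do.
    class_ids = [1 if x > 0.0 else 0 for x in logits]
    vocab = label_vocabulary or ["0", "1"]
    preds = {
        "logits": list(logits),
        "logistic": logistic,
        "probabilities": [[1.0 - p, p] for p in logistic],
        "class_ids": class_ids,
        "classes": [vocab[i] for i in class_ids],
    }
    if keys is None:
        return preds
    # Return only the requested keys.
    return {k: preds[k] for k in keys}
```

For example, binary_predictions([3.0], ["neg", "pos"], keys=["classes"]) returns {"classes": ["pos"]}.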

update_metrics

update_metrics(
    eval_metrics, features, logits, labels, regularization_losses=None
)

Updates eval metrics. See base_head.Head for details.
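update_metrics folds one batch into metric objects that keep state across batches. The class below is a hypothetical, pure-Python stand-in for one such metric (weighted binary accuracy); its update_state and result method names echo Keras metrics, but it is not part of the TensorFlow API.

```python
class StreamingBinaryAccuracy:
    """Accumulates weighted accuracy across batches."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.correct = 0.0  # running sum of weights of correct predictions
        self.total = 0.0    # running sum of all weights

    def update_state(self, labels, probabilities, weights=None):
        # Fold one batch of float labels and predicted probabilities into the state.
        if weights is None:
            weights = [1.0] * len(labels)
        for y, p, w in zip(labels, probabilities, weights):
            pred = 1.0 if p > self.threshold else 0.0
            self.correct += w * (pred == y)
            self.total += w

    def result(self):
        # Accuracy over everything seen so far (0.0 before any update).
        return self.correct / self.total if self.total else 0.0
```

Because the state persists between calls, the result after several batches equals the accuracy over their union.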