Defined in tensorflow/contrib/estimator/__init__.py.
Estimator Python module.
Importing from tensorflow.python.estimator is unsupported and will soon break!
Classes
class RNNClassifier
: A classifier for TensorFlow RNN models.
class RNNEstimator
: An Estimator for TensorFlow RNN models with a user-specified head.
class SavedModelEstimator
: Creates an Estimator from a SavedModel.
class TowerOptimizer
: Gathers gradients from all towers and reduces them in the last one.
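The reduction TowerOptimizer performs can be pictured without TensorFlow. Below is a minimal pure-Python sketch, assuming each tower yields one gradient per variable and that "reduce" means an element-wise sum across towers (optionally divided by the tower count to get an average). `reduce_tower_gradients` and its `scale_by_num_towers` flag are hypothetical names, not part of tf.contrib.estimator.

```python
from typing import Dict, List


def reduce_tower_gradients(
    tower_grads: List[Dict[str, List[float]]],
    scale_by_num_towers: bool = False,
) -> Dict[str, List[float]]:
    """Hypothetical helper: sum each variable's gradient across towers.

    tower_grads[i] maps a variable name to tower i's gradient as a flat list.
    With scale_by_num_towers=True the sum is divided by the number of towers.
    """
    if not tower_grads:
        raise ValueError("need at least one tower")
    reduced: Dict[str, List[float]] = {}
    for name in tower_grads[0]:
        # Element-wise sum of this variable's gradient over every tower.
        summed = [sum(vals) for vals in zip(*(t[name] for t in tower_grads))]
        if scale_by_num_towers:
            summed = [g / len(tower_grads) for g in summed]
        reduced[name] = summed
    return reduced
```

In the real optimizer the reduced gradients are then applied once, on the last tower's device; this sketch only covers the arithmetic.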
Functions
DNNClassifierWithLayerAnnotations(...)
: A classifier for TensorFlow DNN models with layer annotations.
DNNRegressorWithLayerAnnotations(...)
: A regressor for TensorFlow DNN models with layer annotations.
add_metrics(...)
: Creates a new tf.estimator.Estimator that has the given metrics.
binary_classification_head(...)
: Creates a _Head for single-label binary classification.
boosted_trees_classifier_train_in_memory(...)
: Trains a boosted tree classifier with an in-memory dataset.
boosted_trees_regressor_train_in_memory(...)
: Trains a boosted tree regressor with an in-memory dataset.
build_raw_supervised_input_receiver_fn(...)
: Builds a supervised_input_receiver_fn for raw features and labels.
build_supervised_input_receiver_fn_from_input_fn(...)
: Gets a function that returns a SupervisedInputReceiver matching an input_fn.
call_logit_fn(...)
: Calls logit_fn.
clip_gradients_by_norm(...)
: Returns an optimizer which clips gradients before applying them.
dnn_logit_fn_builder(...)
: Function builder for a DNN logit_fn.
export_all_saved_models(...)
: Exports requested train/eval/predict graphs as separate SavedModels. (deprecated)
export_saved_model_for_mode(...)
: Exports a single train/eval/predict graph as a SavedModel. (deprecated)
forward_features(...)
: Forwards features to the predictions dictionary.
linear_logit_fn_builder(...)
: Function builder for a linear logit_fn.
logistic_regression_head(...)
: Creates a _Head for logistic regression.
make_early_stopping_hook(...)
: Creates an early-stopping hook.
multi_class_head(...)
: Creates a _Head for multi-class classification.
multi_head(...)
: Creates a _Head for multi-objective learning.
multi_label_head(...)
: Creates a _Head for multi-label classification.
poisson_regression_head(...)
: Creates a _Head for Poisson regression using tf.nn.log_poisson_loss.
read_eval_metrics(...)
: Helper to read eval metrics from eval summary files.
regression_head(...)
: Creates a _Head for regression using the mean_squared_error loss.
replicate_model_fn(...)
: Replicates Estimator.model_fn over GPUs. (deprecated)
stop_if_higher_hook(...)
: Creates a hook that stops training if the given metric is higher than the threshold.
stop_if_lower_hook(...)
: Creates a hook that stops training if the given metric is lower than the threshold.
stop_if_no_decrease_hook(...)
: Creates a hook that stops training if the given metric does not decrease within the given maximum number of steps.
stop_if_no_increase_hook(...)
: Creates a hook that stops training if the given metric does not increase within the given maximum number of steps.
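The loss behind binary_classification_head (and logistic_regression_head) is sigmoid cross-entropy. A minimal pure-Python sketch for a single example follows, using the numerically stable form that the tf.nn.sigmoid_cross_entropy_with_logits documentation describes; `sigmoid_cross_entropy` is a hypothetical name, and batching and weighting are left out.

```python
import math


def sigmoid_cross_entropy(logit: float, label: float) -> float:
    """Hypothetical helper: stable sigmoid cross-entropy for one example.

    Equivalent to -[z*log(sigmoid(x)) + (1-z)*log(1-sigmoid(x))], rewritten as
    max(x, 0) - x*z + log(1 + exp(-|x|)) so large |x| cannot overflow exp().
    """
    x, z = logit, label
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
```

The rearrangement matters for confident predictions: a logit of 100 with label 1 yields a loss near zero instead of an overflow or `log(0)`.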
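clip_gradients_by_norm wraps an optimizer so gradients are clipped before they are applied. The sketch below assumes the clipping rule is global-norm clipping (the rule behind tf.clip_by_global_norm: every gradient is scaled by `clip_norm / max(global_norm, clip_norm)`); it is plain Python, and `clip_by_global_norm` here is a hypothetical stand-in, not the TensorFlow op.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], clip_norm: float) -> List[List[float]]:
    """Hypothetical helper: jointly rescale gradients so their global L2 norm <= clip_norm."""
    if clip_norm <= 0:
        raise ValueError("clip_norm must be > 0")
    # Global norm: sqrt of the sum of squares over every element of every gradient.
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Scale factor is 1.0 when already within the bound, so nothing changes then.
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads]
```

Because one factor is shared by all gradients, their relative directions are preserved; only the overall step size shrinks.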
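A logit_fn produced by dnn_logit_fn_builder maps input features to logits through fully connected hidden layers. As a rough, framework-free illustration, assuming ReLU hidden activations and a linear output layer (typical DNN defaults, not confirmed by this page), a forward pass looks like this; `dense` and `dnn_logits` are hypothetical names.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Layer = Tuple[Matrix, List[float]]  # (weights[in][out], bias[out])


def dense(x: Sequence[float], weights: Matrix, bias: Sequence[float]) -> List[float]:
    """Fully connected layer: out[j] = sum_i x[i] * weights[i][j] + bias[j]."""
    return [sum(xi * weights[i][j] for i, xi in enumerate(x)) + bias[j]
            for j in range(len(bias))]


def dnn_logits(x: Sequence[float], hidden_layers: List[Layer], output_layer: Layer) -> List[float]:
    """Hypothetical helper: ReLU hidden layers, then a linear layer producing logits."""
    h = list(x)
    for weights, bias in hidden_layers:
        h = [max(0.0, v) for v in dense(h, weights, bias)]  # ReLU activation
    weights, bias = output_layer
    return dense(h, weights, bias)  # logits carry no activation
```

The head (for example a multi-class head) then turns these raw logits into a loss and predictions.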
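forward_features copies selected input features into an estimator's predictions dictionary, which helps when predictions must be joined back to identifiers such as a user ID. The plain-dict sketch below illustrates the idea; `forward_features_sketch` is hypothetical, and treating a key collision as an error (rather than overwriting) is an assumption of this sketch.

```python
from typing import Any, Dict, Iterable, Optional


def forward_features_sketch(features: Dict[str, Any], predictions: Dict[str, Any],
                            keys: Optional[Iterable[str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: return a new predictions dict with features added.

    keys=None forwards every feature. Existing prediction keys are never overwritten.
    """
    keys = list(features) if keys is None else list(keys)
    out = dict(predictions)  # leave the caller's dict untouched
    for key in keys:
        if key not in features:
            raise ValueError(f"feature {key!r} not found")
        if key in out:
            # Assumed policy: refuse to silently replace an existing prediction.
            raise ValueError(f"{key!r} is already in predictions")
        out[key] = features[key]
    return out
```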
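A linear logit_fn, as built by linear_logit_fn_builder, computes a weighted sum of features plus a bias. The sketch below shows that computation for one example with named scalar features; `linear_logit` is hypothetical, and treating absent features as zero is an assumption that stands in for real feature-column handling.

```python
from typing import Dict


def linear_logit(features: Dict[str, float], weights: Dict[str, float], bias: float) -> float:
    """Hypothetical helper: logit = bias + sum_k weights[k] * features[k].

    Features missing from the example contribute 0 (an assumption of this sketch).
    """
    return bias + sum(w * features.get(name, 0.0) for name, w in weights.items())
```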
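For multi_class_head, the per-example loss is assumed here to be sparse softmax cross-entropy: the negative log of the softmax probability assigned to the true class ID. A minimal pure-Python sketch using the log-sum-exp trick for stability follows; `sparse_softmax_cross_entropy` is a hypothetical name.

```python
import math
from typing import Sequence


def sparse_softmax_cross_entropy(logits: Sequence[float], label: int) -> float:
    """Hypothetical helper: -log(softmax(logits)[label]) for one example."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]
```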
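multi_head combines several heads for multi-objective learning. The sketch below assumes the combined training loss is the sum of the per-head losses, or a weighted sum when weights are supplied. `combine_head_losses` is hypothetical, and keying weights by head name is a simplification of this sketch.

```python
from typing import Dict, Optional


def combine_head_losses(losses: Dict[str, float],
                        head_weights: Optional[Dict[str, float]] = None) -> float:
    """Hypothetical helper: total loss across heads.

    Unweighted: plain sum of per-head losses. Weighted: each loss is scaled by
    its head's weight before summing.
    """
    if head_weights is None:
        return sum(losses.values())
    return sum(head_weights[name] * loss for name, loss in losses.items())
```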
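multi_label_head treats each class as an independent yes/no decision. A minimal per-example sketch, assuming the loss is sigmoid cross-entropy computed per class and then averaged over classes, is below; `multi_label_loss` is hypothetical, and batch-level reduction is not modeled.

```python
import math
from typing import Sequence


def multi_label_loss(logits: Sequence[float], labels: Sequence[int]) -> float:
    """Hypothetical helper: sigmoid cross-entropy per class, averaged over classes."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels need one entry per class")
    # Stable sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|)).
    per_class = [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
                 for x, z in zip(logits, labels)]
    return sum(per_class) / len(per_class)
```

Unlike the multi-class case, several labels may be 1 at once, which is why each class gets its own sigmoid instead of sharing one softmax.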
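poisson_regression_head relies on tf.nn.log_poisson_loss, whose input is the predicted log rate `c`. With its default `compute_full_loss=False`, the constant `log(target!)` term is dropped and the loss reduces to `exp(c) - target * c`. A pure-Python sketch of that formula (hypothetical helper name) follows.

```python
import math


def log_poisson_loss(target: float, log_input: float) -> float:
    """Hypothetical helper: Poisson NLL in terms of the log rate, constant term dropped."""
    return math.exp(log_input) - target * log_input
```

The loss is convex in `c` and smallest where `exp(c) == target`, so the head's prediction is naturally read as `exp(logits)`.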
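regression_head uses the mean_squared_error loss. For scalar targets, that means averaging squared residuals over the batch, as in the sketch below. `mean_squared_error` here is a hypothetical plain-Python function, and how the real head reduces over the batch (mean versus sum) is configurable and not modeled.

```python
from typing import Sequence


def mean_squared_error(labels: Sequence[float], predictions: Sequence[float]) -> float:
    """Hypothetical helper: mean of (prediction - label)^2 over a batch."""
    if len(labels) != len(predictions) or not labels:
        raise ValueError("need equal-length, non-empty inputs")
    return sum((p - y) ** 2 for y, p in zip(labels, predictions)) / len(labels)
```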
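stop_if_higher_hook and stop_if_lower_hook stop training once an evaluated metric crosses a threshold. The sketch replays an evaluation history shaped like `{global_step: {metric_name: value}}` (the shape assumed here for read_eval_metrics output) and reports the first step at which training would stop. `first_stop_step` is hypothetical, and the strict comparisons are an assumption.

```python
from typing import Dict, Optional


def first_stop_step(eval_history: Dict[int, Dict[str, float]], metric_name: str,
                    threshold: float, higher: bool) -> Optional[int]:
    """Hypothetical helper: first evaluated step where training would stop, else None.

    higher=True mimics "stop if higher" (value > threshold);
    higher=False mimics "stop if lower" (value < threshold).
    """
    for step in sorted(eval_history):  # evaluate in global-step order
        value = eval_history[step][metric_name]
        if (value > threshold) if higher else (value < threshold):
            return step
    return None
```

Typical uses are stopping once accuracy exceeds a target, or once loss drops below one.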
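stop_if_no_decrease_hook and stop_if_no_increase_hook stop training when a metric has not improved for a given number of steps. The sketch walks the evaluation history in step order, tracks the step of the best value so far, and signals a stop once an evaluation lies at least `max_steps_without_improvement` steps past it. The helper name and the `>=` comparison are assumptions, so treat this as an approximation of the hooks' logic.

```python
from typing import Dict, Optional


def no_improvement_should_stop(eval_history: Dict[int, Dict[str, float]],
                               metric_name: str,
                               max_steps_without_improvement: int,
                               lower_is_better: bool = True) -> bool:
    """Hypothetical helper: True if the metric stalls for too many steps.

    lower_is_better=True mimics "no decrease" (e.g. loss);
    lower_is_better=False mimics "no increase" (e.g. accuracy).
    """
    best_value: Optional[float] = None
    best_step = 0
    for step in sorted(eval_history):
        value = eval_history[step][metric_name]
        improved = (best_value is None
                    or (value < best_value if lower_is_better else value > best_value))
        if improved:
            best_value, best_step = value, step
        # Steps elapsed since the best evaluation seen so far.
        if step - best_step >= max_steps_without_improvement:
            return True
    return False
```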