Hook to run evaluation in training without a checkpoint.
Inherits From: SessionRunHook
tf.estimator.experimental.InMemoryEvaluatorHook(
estimator, input_fn, steps=None, hooks=None, name=None, every_n_iter=100
)
Example:
import tensorflow as tf

def train_input_fn():
  ...
  return train_dataset

def eval_input_fn():
  ...
  return eval_dataset

estimator = tf.estimator.DNNClassifier(...)

# Evaluates in memory (no checkpoint), every 100 training iterations by default.
evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
    estimator, eval_input_fn)
estimator.train(train_input_fn, hooks=[evaluator])
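Current limitations of this approach are:
- It doesn't support multi-node distributed mode.
- It doesn't support saveable objects other than variables (such as boosted tree support).
- It doesn't support custom saver logic (such as ExponentialMovingAverage support).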
Args
estimator
: A tf.estimator.Estimator instance on which to call evaluate.
input_fn
: Equivalent to the input_fn arg to estimator.evaluate. A function that constructs the input data for evaluation. See Creating input functions for more information. The function should construct and return one of the following:
- A tf.data.Dataset object: outputs of the Dataset object must be a tuple (features, labels) with the same constraints as below.
- A tuple (features, labels): where features is a Tensor or a dictionary of string feature name to Tensor, and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
steps
: Equivalent to the steps arg to estimator.evaluate. Number of steps for which to evaluate the model. If None, evaluates until input_fn raises an end-of-input exception.
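The steps semantics can be illustrated with a small plain-Python sketch. It is hypothetical and does not use TensorFlow: run_eval_loop stands in for the evaluation loop, the batches iterable stands in for what input_fn produces, and Python's StopIteration plays the role of the end-of-input exception.

```python
from typing import Iterable, Optional


def run_eval_loop(batches: Iterable[float], steps: Optional[int] = None) -> int:
    """Hypothetical stand-in for an evaluation loop; returns steps completed.

    Consumes one batch per step. With steps=None it stops only when the
    input is exhausted; otherwise it stops after `steps` batches, or
    earlier if the input runs out first.
    """
    iterator = iter(batches)
    completed = 0
    while steps is None or completed < steps:
        try:
            next(iterator)  # "evaluate" one batch
        except StopIteration:
            break  # end of input reached
        completed += 1
    return completed


print(run_eval_loop(range(10)))           # steps=None: all 10 batches -> 10
print(run_eval_loop(range(10), steps=4))  # stops after 4 batches -> 4
```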
hooks
: Equivalent to the hooks arg to estimator.evaluate. List of SessionRunHook subclass instances. Used for callbacks inside the evaluation call.
name
: Equivalent to the name arg to estimator.evaluate. Name of the evaluation, if the user needs to run multiple evaluations on different data sets, such as on training data vs. test data. Metrics for different evaluations are saved in separate folders and appear separately in TensorBoard.
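The "separate folders" behavior can be sketched as a naming rule. The helper below is hypothetical; to our understanding it mirrors the convention of tf.estimator.Estimator.eval_dir (an unnamed evaluation writes to an "eval" subdirectory of the model directory, a named one to "eval_<name>"), but treat that mapping as an assumption rather than the library's code.

```python
import os
from typing import Optional


def eval_dir_for(model_dir: str, name: Optional[str] = None) -> str:
    """Hypothetical helper: directory where metrics for one evaluation go.

    Assumed convention: <model_dir>/eval for unnamed evaluations and
    <model_dir>/eval_<name> for named ones, so each shows up separately
    in TensorBoard.
    """
    return os.path.join(model_dir, "eval_" + name if name else "eval")


print(eval_dir_for("/tmp/model"))           # unnamed evaluation
print(eval_dir_for("/tmp/model", "train"))  # e.g. evaluation on training data
```

Keeping the name as a directory suffix means two evaluations with different names never overwrite each other's summaries.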
every_n_iter
: int, runs the evaluator once every N training iterations.
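The every_n_iter schedule and its validation can be sketched in plain Python. The function below is hypothetical, not part of TensorFlow: it only lists which training iterations would trigger a periodic evaluation and rejects non-positive values, as the ValueError entry describes. It ignores the extra evaluations that happen before training and at the end.

```python
from typing import List


def evaluation_iterations(every_n_iter: int, total_iterations: int) -> List[int]:
    """Hypothetical sketch: training iterations that trigger a periodic eval."""
    if not isinstance(every_n_iter, int) or every_n_iter <= 0:
        # Mirrors the documented contract: non-positive values are rejected.
        raise ValueError(f"every_n_iter must be a positive int, got {every_n_iter!r}")
    # One evaluation every N training iterations, counting from 1.
    return [i for i in range(1, total_iterations + 1) if i % every_n_iter == 0]


print(evaluation_iterations(100, 350))  # [100, 200, 300]
```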
Raises
ValueError
: if every_n_iter is non-positive, or if training is not on a single machine.
Methods
after_create_session
after_create_session(
session, coord
)
Performs an initial evaluation run, which shows the eval metrics before training starts.
after_run
after_run(
run_context, run_values
)
Runs the evaluator once every every_n_iter training iterations.
before_run
before_run(
run_context
)
Called before each call to run().
You can return from this call a SessionRunArgs object indicating ops or tensors to add to the upcoming run() call. These ops/tensors will be run together with the ops/tensors originally passed to the run() call. The run args you return can also contain feeds to be added to the run() call.
The run_context argument is a SessionRunContext that provides information about the upcoming run() call: the originally requested op/tensors and the TensorFlow Session.
At this point the graph is finalized and you cannot add ops.
Args
run_context
: A SessionRunContext object.
Returns
None or a SessionRunArgs object.
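The before_run contract can be modeled in plain Python. This sketch is hypothetical and is not the TensorFlow API: RunArgs stands in for SessionRunArgs, a dict stands in for SessionRunContext, and run_with_hooks stands in for run(). A hook requests an extra fetch in before_run, every requested fetch is computed in the same run as the original one, and the results reach each hook in after_run.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class RunArgs:
    """Hypothetical stand-in for SessionRunArgs: extra fetches a hook requests."""
    fetches: Dict[str, Callable[[], Any]] = field(default_factory=dict)


class RecordingHook:
    """Hypothetical hook that requests the current step on every run."""

    def __init__(self) -> None:
        self.seen: List[Any] = []

    def before_run(self, run_context: Dict[str, Any]) -> Optional[RunArgs]:
        # Ask for one extra value to be computed alongside the original fetch.
        return RunArgs(fetches={"global_step": lambda: run_context["step"]})

    def after_run(self, run_context: Dict[str, Any], run_values: Dict[str, Any]) -> None:
        # run_values holds the results of the fetches requested in before_run.
        self.seen.append(run_values["global_step"])


def run_with_hooks(step: int, original_fetch: Callable[[], Any], hooks: List[Any]) -> Any:
    """Hypothetical run(): computes the original fetch plus all hook fetches."""
    run_context = {"step": step}
    requested = [hook.before_run(run_context) for hook in hooks]
    result = original_fetch()  # the originally requested op/tensor
    # A hook may return None, meaning it requests nothing extra.
    hook_values = [
        {key: fetch() for key, fetch in (args.fetches if args is not None else {}).items()}
        for args in requested
    ]
    # Only after the whole "run" are results handed back to the hooks.
    for hook, values in zip(hooks, hook_values):
        hook.after_run(run_context, values)
    return result


recorder = RecordingHook()
for step in range(3):
    run_with_hooks(step, lambda: "train_op", [recorder])
print(recorder.seen)  # [0, 1, 2]
```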
begin
begin()
Builds the eval graph and the restoring op.
end
end(
session
)
Runs the evaluator for the final model.
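Taken together, the method descriptions above imply an evaluation schedule: once before training (after_create_session), once every every_n_iter training iterations (after_run), and once for the final model (end). The class below is a hypothetical plain-Python model of that schedule, not TensorFlow code; evaluate is a caller-supplied callback that receives the current iteration count, and the graph building and variable restoring done in begin are omitted.

```python
from typing import Callable, List


class InMemoryEvalScheduleSketch:
    """Hypothetical model of when an in-memory evaluator hook evaluates."""

    def __init__(self, evaluate: Callable[[int], None], every_n_iter: int = 100) -> None:
        if every_n_iter <= 0:
            raise ValueError("every_n_iter must be positive")
        self._evaluate_fn = evaluate
        self._every_n_iter = every_n_iter
        self._iter_count = 0
        self._last_eval_iter = 0

    def _evaluate(self) -> None:
        # Run the callback and remember when the last evaluation happened.
        self._evaluate_fn(self._iter_count)
        self._last_eval_iter = self._iter_count

    def after_create_session(self) -> None:
        # Initial evaluation, before any training iteration has run.
        self._evaluate()

    def after_run(self) -> None:
        # Count one training iteration; evaluate once N have passed since the last eval.
        self._iter_count += 1
        if self._iter_count - self._last_eval_iter >= self._every_n_iter:
            self._evaluate()

    def end(self) -> None:
        # Final evaluation of the trained model.
        self._evaluate()


evaluated_at: List[int] = []
schedule = InMemoryEvalScheduleSketch(evaluated_at.append, every_n_iter=3)
schedule.after_create_session()
for _ in range(7):
    schedule.after_run()
schedule.end()
print(evaluated_at)  # [0, 3, 6, 7]
```

With every_n_iter=3 and seven training iterations, the sketch evaluates at iterations 0, 3, 6, and 7. If training stopped exactly on a multiple of N, end would evaluate the same model a second time.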