Class TrainSpec
Configuration for the "train" part of the train_and_evaluate call.

TrainSpec determines the input data for training, as well as the training duration. Optional hooks run at various stages of training.
__new__
@staticmethod
__new__(
    cls,
    input_fn,
    max_steps=None,
    hooks=None
)
Creates a validated TrainSpec instance.
Args:
input_fn: A function that provides input data for training as minibatches. See Premade Estimators for more information. The function should construct and return one of the following:
- A tf.data.Dataset object: Outputs of the Dataset object must be a tuple (features, labels) with the same constraints as below.
- A tuple (features, labels): Where features is a Tensor or a dictionary of string feature name to Tensor, and labels is a Tensor or a dictionary of string label name to Tensor.
max_steps: Int. Positive number of total steps for which to train the model. If None, train forever. The training input_fn is not expected to generate OutOfRangeError or StopIteration exceptions. See the train_and_evaluate stop condition section for details.
hooks: Iterable of tf.train.SessionRunHook objects to run on all workers (including chief) during training.
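The max_steps semantics (a total step count, with None meaning no limit) and the expectation that the training input_fn keeps producing batches can be illustrated with a plain-Python sketch. This is not Estimator code: toy_input_fn and run_training are hypothetical names, plain lists stand in for Tensors, and a Python generator stands in for a tf.data.Dataset. In this sketch, because max_steps is a total rather than an increment, a run resumed at global step 40 with max_steps=100 performs only 60 more steps, and a run resumed at step 100 performs none.

```python
from typing import Callable, Dict, Iterator, List, Optional, Tuple

# Hypothetical stand-ins: plain lists play the role of Tensors.
Features = Dict[str, List[float]]
Labels = List[int]
Batch = Tuple[Features, Labels]


def toy_input_fn() -> Iterator[Batch]:
    """Hypothetical input_fn yielding (features, labels) minibatches forever.

    It never raises StopIteration, so only the step budget ends training.
    """
    while True:
        yield {"x": [1.0, 2.0]}, [0, 1]


def run_training(input_fn: Callable[[], Iterator[Batch]],
                 max_steps: Optional[int],
                 global_step: int = 0) -> int:
    """Train until global_step reaches max_steps.

    Returns the number of steps run in this call. max_steps is a total,
    so a run resumed from an earlier global_step only performs the
    remaining steps. With max_steps=None the loop never terminates.
    """
    steps_run = 0
    batches = input_fn()
    while max_steps is None or global_step < max_steps:
        features, labels = next(batches)
        # A real trainer would compute a loss on (features, labels) and
        # apply a gradient update here; this sketch only counts steps.
        global_step += 1
        steps_run += 1
    return steps_run


print(run_training(toy_input_fn, max_steps=100))                   # 100
print(run_training(toy_input_fn, max_steps=100, global_step=40))   # 60
print(run_training(toy_input_fn, max_steps=100, global_step=100))  # 0
```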
Returns:
A validated TrainSpec object.
Raises:
ValueError: If any of the input arguments is invalid.
TypeError: If any of the arguments is not of the expected type.
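The checks listed under Raises can be mirrored in a short pure-Python sketch. TrainSpecSketch and the local SessionRunHook class are hypothetical placeholders, not TensorFlow classes. Treating a non-integer max_steps as a TypeError is an assumption about how "not of the expected type" applies. As with the documented __new__, the sketch validates its arguments before the immutable namedtuple instance is created.

```python
import collections


class SessionRunHook:
    """Hypothetical placeholder for tf.train.SessionRunHook."""


class TrainSpecSketch(
        collections.namedtuple("TrainSpecSketch",
                               ["input_fn", "max_steps", "hooks"])):
    """Hypothetical stand-in for TrainSpec: validates, then stores fields."""

    def __new__(cls, input_fn, max_steps=None, hooks=None):
        # input_fn must be a callable that builds the training input.
        if not callable(input_fn):
            raise TypeError(f"input_fn must be callable, given: {input_fn!r}")
        # max_steps is either None (train forever) or a positive int.
        # Assumption: a non-int max_steps is reported as a TypeError.
        if max_steps is not None:
            if isinstance(max_steps, bool) or not isinstance(max_steps, int):
                raise TypeError(f"max_steps must be an int, given: {max_steps!r}")
            if max_steps <= 0:
                raise ValueError(f"max_steps must be > 0, given: {max_steps}")
        # hooks may be any iterable (or None); store it as a tuple.
        hooks = tuple(hooks or ())
        for hook in hooks:
            if not isinstance(hook, SessionRunHook):
                raise TypeError(
                    f"hooks must contain SessionRunHook objects, given: {hook!r}")
        return super().__new__(cls, input_fn, max_steps, hooks)


spec = TrainSpecSketch(lambda: None, max_steps=1000, hooks=[SessionRunHook()])
print(spec.max_steps, len(spec.hooks))  # 1000 1
```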