Class StandardSingleLossStep
Inherits From: StandardInputStep
Defined in tensorflow/contrib/distribute/python/step_fn.py
A step function that implements a training step for a feed-forward network.
An instance of this class is intended to be used as a callable:
...
step = step_fn.StandardSingleLossStep(
    dataset_fn, loss_fn, optimizer, distribution)
# Run a single training step on a given DistributionStrategy:
step()
...
Args:
dataset_fn: a function that returns a tf.data Dataset that produces the input for the model.
loss_fn: a function that takes a context and inputs as arguments. It returns the loss for those inputs. context is an instance of values.MultiStepContext that will be passed when loss_fn is run. context can be used to specify the outputs to be returned from loss_fn, among other things.
optimizer: an optimizer that implements an update rule.
distribution: a DistributionStrategy object.
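The calling contract described by these arguments can be illustrated with a small TensorFlow-free sketch. Everything in it is a hypothetical stand-in, not the library's implementation: ToyContext, ToyStep, params, and learning_rate are invented names, the optimizer is replaced by a plain SGD update on a finite-difference gradient, and no DistributionStrategy is involved. It also assumes that iterations_per_step is the number of updates performed per call, which this page does not state.

```python
# Toy, TensorFlow-free stand-in. All names here (ToyContext, ToyStep,
# params, learning_rate) are hypothetical and exist only for illustration.

params = {"w": 0.0}  # a single trainable scalar


def dataset_fn():
    # Returns an iterable of (x, y) input pairs; here y = 3 * x.
    return [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)] * 50


class ToyContext:
    """Stand-in for the context object handed to loss_fn."""

    def __init__(self):
        self.outputs = {}


def loss_fn(context, inputs):
    # Takes a context and inputs, returns the loss for those inputs.
    x, y = inputs
    loss = (params["w"] * x - y) ** 2
    context.outputs["loss"] = loss  # most recently evaluated loss
    return loss


class ToyStep:
    """Callable that performs iterations_per_step updates per call (assumed)."""

    def __init__(self, dataset_fn, loss_fn, learning_rate, iterations_per_step=1):
        self._iterator = iter(dataset_fn())
        self._loss_fn = loss_fn
        self._learning_rate = learning_rate
        self._iterations_per_step = iterations_per_step
        self.context = ToyContext()

    def __call__(self):
        eps = 1e-6
        for _ in range(self._iterations_per_step):
            inputs = next(self._iterator)
            w = params["w"]
            # Central-difference estimate of d(loss)/dw.
            params["w"] = w + eps
            loss_up = self._loss_fn(self.context, inputs)
            params["w"] = w - eps
            loss_down = self._loss_fn(self.context, inputs)
            grad = (loss_up - loss_down) / (2 * eps)
            # Plain SGD update, standing in for the optimizer argument.
            params["w"] = w - self._learning_rate * grad


step = ToyStep(dataset_fn, loss_fn, learning_rate=0.05, iterations_per_step=5)
for _ in range(10):
    step()  # each call consumes five inputs and applies five updates
print(params["w"])
```

The point of the sketch is the shape of the API: the input function and loss function are bound once at construction, and the resulting object is then called repeatedly with no arguments.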
__init__
__init__(
dataset_fn,
loss_fn,
optimizer,
distribution,
iterations_per_step=1
)
Initialize self. See help(type(self)) for accurate signature.
Properties
distribution
Methods
tf.contrib.distribute.StandardSingleLossStep.__call__
__call__()
Perform one step of this training algorithm.