tf.contrib.model_pruning.train(
train_op,
logdir,
mask_update_op,
train_step_fn=train_step,
train_step_kwargs=_USE_DEFAULT,
log_every_n_steps=1,
graph=None,
master='',
is_chief=True,
global_step=None,
number_of_steps=None,
init_op=_USE_DEFAULT,
init_feed_dict=None,
local_init_op=_USE_DEFAULT,
init_fn=None,
ready_op=_USE_DEFAULT,
summary_op=_USE_DEFAULT,
save_summaries_secs=600,
summary_writer=_USE_DEFAULT,
startup_delay_steps=0,
saver=None,
save_interval_secs=600,
sync_optimizer=None,
session_config=None,
trace_every_n_steps=None
)
Defined in tensorflow/contrib/model_pruning/python/learning.py.
Wrapper around tf-slim's train function.
Runs a training loop using a TensorFlow supervisor. When the sync_optimizer is supplied, gradient updates are applied synchronously. Otherwise, gradient updates are applied asynchronously.
Args:
train_op: A Tensor that, when executed, will apply the gradients and return the loss value.
logdir: The directory where training logs are written to. If None, model checkpoints and summaries will not be written.
mask_update_op: Operation that, upon execution, updates the weight masks and thresholds.
train_step_fn: The function to call in order to execute a single gradient step. The function must take exactly four arguments: the current session, the train_op Tensor, a global step Tensor and a dictionary.
train_step_kwargs: A dictionary which is passed to the train_step_fn. By default, two Boolean, scalar ops called "should_stop" and "should_log" are provided.
log_every_n_steps: The frequency, in terms of global steps, at which the loss and global step are logged.
graph: The graph to pass to the supervisor. If no graph is supplied the default graph is used.
master: The address of the TensorFlow master.
is_chief: Specifies whether or not the training is being run by the primary replica during replica training.
global_step: The Tensor representing the global step. If left as None, then slim.variables.get_or_create_global_step() is used.
number_of_steps: The max number of gradient steps to take during training, as measured by 'global_step': training will stop if global_step is greater than 'number_of_steps'. If the value is left as None, training proceeds indefinitely.
init_op: The initialization operation. If left to its default value, then the session is initialized by calling tf.global_variables_initializer().
init_feed_dict: A feed dictionary to use when executing the init_op.
local_init_op: The local initialization operation. If left to its default value, then the session is initialized by calling tf.local_variables_initializer() and tf.tables_initializer().
init_fn: An optional callable to be executed after init_op is called. The callable must accept one argument, the session being initialized.
ready_op: Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling tf.report_uninitialized_variables().
summary_op: The summary operation.
save_summaries_secs: How often, in seconds, to save summaries.
summary_writer: SummaryWriter to use. Can be None to indicate that no summaries should be written. If unset, a SummaryWriter is created.
startup_delay_steps: The number of steps to wait for before beginning. Note that this must be 0 if a sync_optimizer is supplied.
saver: Saver to save checkpoints. If None, a default one will be created and used.
save_interval_secs: How often, in seconds, to save the model to logdir.
sync_optimizer: An instance of tf.train.SyncReplicasOptimizer, or a list of them. If the argument is supplied, gradient updates will be synchronous. If left as None, gradient updates will be asynchronous.
session_config: An instance of tf.ConfigProto that will be used to configure the Session. If left as None, the default will be used.
trace_every_n_steps: Produce and save a Timeline in Chrome trace format and add it to the summaries every trace_every_n_steps. If None, no trace information will be produced or saved.
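The four-argument contract for train_step_fn described above can be illustrated with a small sketch. It is not the library's default step function. FakeSession is a hypothetical stand-in for a real session so that the example runs without TensorFlow, the "max_steps" key is invented for illustration, and the (loss, should_stop) return pair is assumed from the behavior of tf-slim's default train_step rather than stated on this page.

```python
class FakeSession:
    """Hypothetical stand-in for tf.Session (not part of TensorFlow)."""

    def __init__(self):
        self.step = 0

    def run(self, fetches):
        # Pretend each call applies one gradient step and returns
        # the loss value together with the updated global step.
        self.step += 1
        return 1.0 / self.step, self.step


def my_train_step(sess, train_op, global_step, train_step_kwargs):
    """Custom step function with the four arguments train_step_fn requires."""
    total_loss, step = sess.run([train_op, global_step])
    # "max_steps" is an illustrative key, not one provided by default.
    max_steps = train_step_kwargs.get("max_steps")
    should_stop = max_steps is not None and step >= max_steps
    # Assumed return shape: (loss value, whether training should stop).
    return total_loss, should_stop
```

With a real graph, such a function would be passed as train_step_fn=my_train_step, with its dictionary supplied through train_step_kwargs.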
Returns:
the value of the loss function after training.
Raises:
ValueError: if train_op is empty, if startup_delay_steps is non-zero when sync_optimizer is supplied, if number_of_steps is negative, or if trace_every_n_steps is not None and no logdir is provided.
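The documented error conditions can be summarized as a pre-flight check. The helper below, check_train_args, is hypothetical and is not the library's own validation code; it only mirrors the four conditions listed above, and it treats an "empty" train_op as None, which is an assumption.

```python
def check_train_args(train_op, logdir=None, startup_delay_steps=0,
                     sync_optimizer=None, number_of_steps=None,
                     trace_every_n_steps=None):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Assumption: an "empty" train_op means None.
    if train_op is None:
        raise ValueError("train_op must be provided.")
    # A startup delay is incompatible with synchronous updates.
    if sync_optimizer is not None and startup_delay_steps != 0:
        raise ValueError(
            "startup_delay_steps must be 0 when sync_optimizer is supplied.")
    if number_of_steps is not None and number_of_steps < 0:
        raise ValueError("number_of_steps must not be negative.")
    # Traces are added to the summaries, so they need a log directory.
    if trace_every_n_steps is not None and logdir is None:
        raise ValueError(
            "logdir is required when trace_every_n_steps is set.")
```

Running a check of this kind before building the session surfaces configuration mistakes early, for example requesting traces without a log directory.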