chainer.training.Trainer
- class chainer.training.Trainer(updater, stop_trigger=None, out='result', extensions=None)
The standard training loop in Chainer.
Trainer is an implementation of a training loop. Users can invoke the training by calling the run() method. Each iteration of the training loop proceeds as follows.
1. Update of the parameters. This includes loading the mini-batch, running the forward and backward computations, and executing the update formula. All of these are done by the updater object held by the trainer.
2. Invocation of trainer extensions in the descending order of their priorities. A trigger object is attached to each extension, and it decides at each iteration whether the extension should be executed. Trigger objects are callable objects that take the trainer object as the argument and return a boolean value indicating whether the extension should be called.
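The two steps above can be illustrated with a toy loop. This is a minimal sketch, not Chainer's implementation: ToyUpdater, ToyTrainer, and the extension and trigger names are hypothetical. It also shows the three kinds of callables that can serve as extensions (a class with __call__, a plain function, and a lambda), each taking the trainer as its only argument.

```python
# Toy model of one training iteration: update first, then triggered extensions
# in descending priority order. ToyUpdater, ToyTrainer, and the extension
# names are hypothetical and do not reproduce Chainer's implementation.


class ToyUpdater:
    """Stands in for the updater: each update() call is one parameter update."""

    def __init__(self):
        self.iteration = 0

    def update(self):
        # A real updater would load a mini-batch, run the forward and backward
        # computations, and apply the update formula here.
        self.iteration += 1


class ToyTrainer:
    def __init__(self, updater, stop_iteration):
        self.updater = updater
        self.stop_iteration = stop_iteration
        self.extensions = []  # (name, extension, trigger, priority) tuples
        self.log = []

    def extend(self, name, extension, trigger, priority):
        self.extensions.append((name, extension, trigger, priority))

    def run(self):
        # sorted() is stable, so equal priorities keep registration order.
        ordered = sorted(self.extensions, key=lambda e: e[3], reverse=True)
        while self.updater.iteration < self.stop_iteration:
            self.updater.update()                     # 1. update the parameters
            for _, extension, trigger, _ in ordered:  # 2. invoke extensions
                if trigger(self):                     # trigger(trainer) -> bool
                    extension(self)                   # extension(trainer)


# Three ways to write an extension: a class with __call__, a function, a lambda.
class RecordIteration:
    def __call__(self, trainer):
        trainer.log.append(('record', trainer.updater.iteration))


def read_log(trainer):
    trainer.log.append(('read', trainer.updater.iteration))


def every_iteration(trainer):
    return True


def every_second_iteration(trainer):
    return trainer.updater.iteration % 2 == 0


trainer = ToyTrainer(ToyUpdater(), stop_iteration=4)
trainer.extend('read', read_log, every_second_iteration, priority=100)
trainer.extend('record', RecordIteration(), every_iteration, priority=300)
trainer.extend('mark', lambda t: t.log.append(('mark', t.updater.iteration)),
               every_second_iteration, priority=100)
trainer.run()
```

Here 'record' has the highest priority, so it runs first at every iteration, while 'read' and 'mark' share a priority and run in registration order, only on even iterations.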
Extensions are callable objects that take the trainer object as the argument. There are three ways to define custom extensions: inheriting the Extension class, decorating functions by make_extension(), and defining any callable, including lambda functions. See Extension for more details on custom extensions and how to configure them.
Users can register extensions to the trainer by calling the extend() method, where the following configurations can be given:
- Trigger object, which is also explained above. In most cases, IntervalTrigger is used, in which case users can simply specify a tuple of the interval length and its unit, like (1000, 'iteration') or (1, 'epoch').
- The order of execution of extensions is determined by their priorities. Extensions of higher priorities are invoked earlier. There are three standard values for the priorities:
  - PRIORITY_WRITER. This is the priority for extensions that write records to the observation dictionary. It covers cases where the extension directly adds values to the observation dictionary, or where the extension uses the chainer.report() function to report values to the observation dictionary.
  - PRIORITY_EDITOR. This is the priority for extensions that edit the observation dictionary based on already reported values.
  - PRIORITY_READER. This is the priority for extensions that only read records from the observation dictionary. This is also suitable for extensions that do not use the observation dictionary at all.
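The tuple shorthand and the priority ordering can be sketched in plain Python. ToyIntervalTrigger is a hypothetical stand-in for IntervalTrigger that assumes a fixed number of iterations per epoch, and the numeric priority values are assumptions; only their relative order (writer above editor above reader) comes from the text above.

```python
from types import SimpleNamespace

# Toy stand-ins. The numeric priority values are assumptions; only their
# ordering (writer > editor > reader) follows the text above.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100


class ToyIntervalTrigger:
    """Hypothetical interval trigger built from (period, unit), e.g. (2, 'iteration')."""

    def __init__(self, period, unit):
        if unit not in ('iteration', 'epoch'):
            raise ValueError("unit must be 'iteration' or 'epoch'")
        self.period = period
        self.unit = unit

    def __call__(self, trainer):
        it = trainer.iteration
        if self.unit == 'iteration':
            return it % self.period == 0
        # 'epoch': fire only on the iteration that completes an epoch whose
        # number is a multiple of the period (a fixed epoch length is assumed).
        per_epoch = trainer.iterations_per_epoch
        return it % per_epoch == 0 and (it // per_epoch) % self.period == 0


every_2_iterations = ToyIntervalTrigger(*(2, 'iteration'))
every_epoch = ToyIntervalTrigger(*(1, 'epoch'))
state = SimpleNamespace(iteration=0, iterations_per_epoch=3)
fired = []
for i in range(1, 7):
    state.iteration = i
    fired.append((i, every_2_iterations(state), every_epoch(state)))

# Sorting by descending priority runs writers, then editors, then readers.
registered = [('print_report', PRIORITY_READER),
              ('report_loss', PRIORITY_WRITER),
              ('average_loss', PRIORITY_EDITOR)]
order = [name for name, _ in sorted(registered, key=lambda e: e[1], reverse=True)]
```

Running writers first guarantees that editors and readers in the same iteration see the values reported during that iteration.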
The current state of the trainer object and objects handled by the trainer can be serialized through the standard serialization protocol of Chainer. It enables us to easily suspend and resume the training loop.
>>> from chainer import serializers
>>> serializers.save_npz('my.trainer', trainer)  # To suspend and save
>>> serializers.load_npz('my.trainer', trainer)  # To load and resume
The chainer.training.extensions.snapshot() extension makes regular snapshots of the Trainer object during training.
Note
The serialization does not recover everything of the training loop. It only recovers the states which change over the training (e.g. parameters, optimizer states, the batch iterator state, extension states, etc.). You must initialize the objects correctly before deserializing the states.
On the other hand, it means that users can change the settings on deserialization. For example, the exit condition can be changed at deserialization, so users can train the model for some iterations, suspend it, and then resume it with a larger number of total iterations.
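The note above can be illustrated with a toy trainer. This is a minimal sketch: ToyTrainer and its state_dict/load_state methods are hypothetical, and JSON is used only for brevity in place of Chainer's serializers. The point is that the exit condition is a constructor setting rather than saved state, so a resumed trainer must be constructed first and can use a larger limit.

```python
import json

# Toy model of suspend/resume. ToyTrainer, state_dict, and load_state are
# hypothetical; Chainer uses serializers.save_npz/load_npz with its own
# serialization protocol instead of JSON.


class ToyTrainer:
    def __init__(self, stop_iteration):
        # Settings such as the exit condition are constructor arguments and
        # are NOT part of the saved state.
        self.stop_iteration = stop_iteration
        self.iteration = 0
        self.weight = 0.0

    def run(self):
        while self.iteration < self.stop_iteration:
            self.iteration += 1
            self.weight += 0.5  # stands in for a parameter update

    def state_dict(self):
        # Only the state that changes during training is saved.
        return {'iteration': self.iteration, 'weight': self.weight}

    def load_state(self, state):
        self.iteration = state['iteration']
        self.weight = state['weight']


first = ToyTrainer(stop_iteration=4)
first.run()
saved = json.dumps(first.state_dict())    # suspend

resumed = ToyTrainer(stop_iteration=10)   # initialize first, with a larger limit
resumed.load_state(json.loads(saved))     # then restore the saved state
resumed.run()                             # continues from iteration 4 up to 10
```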
During the training, it also creates a Reporter object to store observed values on each update. For each iteration, it creates a fresh observation dictionary and stores it in the observation attribute.
Links of the target model of each optimizer are registered to the reporter object as observers, where the name of each observer is constructed in the format <optimizer name><link name>. The link name is given by the chainer.Link.namedlinks() method, which represents the path to each link in the hierarchy. Other observers can be registered by accessing the reporter object via the reporter attribute.
The default trainer is plain, i.e., it does not contain any extensions.
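The <optimizer name><link name> naming rule can be sketched as plain string concatenation over a link hierarchy. This is a toy model under stated assumptions: the model is a nested dict, toy_namedlinks and observer_names are hypothetical helpers, link paths are assumed to look like '/predictor/l1', and registering the model itself under the bare optimizer name is also an assumption.

```python
# Toy model of observer naming. The nested-dict model and the helper names
# are hypothetical; the '/child/grandchild' path form is an assumption about
# how link paths look.


def toy_namedlinks(children, prefix=''):
    """Yield '/a', '/a/b', ... style paths for every nested child link."""
    for name, grandchildren in children.items():
        path = prefix + '/' + name
        yield path
        yield from toy_namedlinks(grandchildren, path)


def observer_names(optimizers):
    names = []
    for opt_name, model_children in optimizers.items():
        names.append(opt_name)  # assumed: the target model itself
        # <optimizer name><link name>, e.g. 'main' + '/predictor'
        names.extend(opt_name + path for path in toy_namedlinks(model_children))
    return names


model = {'predictor': {'l1': {}, 'l2': {}}}
names = observer_names({'main': model})
```

With these assumptions, a value reported by the l1 link would be keyed under the 'main/predictor/l1' prefix.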
- Parameters
updater (Updater) – Updater object. It defines how to update the models.
stop_trigger – Trigger that determines when to stop the training loop. If it is not callable, it is passed to IntervalTrigger.
out – Output directory.
extensions – Extensions registered to the trainer.
- Variables
updater – The updater object for this trainer.
stop_trigger – Trigger that determines when to stop the training loop. The training loop stops at the iteration on which this trigger returns True.
observation – Observation of values made at the last update. See the Reporter class for details.
out – Output directory.
reporter – Reporter object to report observed values.
Methods
- extend(extension, name=None, trigger=None, priority=None, **kwargs)
Registers an extension to the trainer.
Extension is a callable object which is called after each update unless the corresponding trigger object decides to skip the iteration. The order of execution is determined by priorities: extensions with higher priorities are called earlier in each iteration. Extensions with the same priority are invoked in the order of registration.
If two or more extensions with the same name are registered, suffixes are added to the names of the second and subsequent extensions. The suffix is _N where N is the ordinal of the extension.
See Extension for the interface of extensions.
- Parameters
extension – Extension to register.
name (str) – Name of the extension. If it is omitted or set to None, the Extension.name attribute of the extension is used, or the Extension.default_name attribute if Extension.name is None or undefined. Note that the name is suffixed by an ordinal in case of duplicated names, as explained above.
trigger (tuple or Trigger) – Trigger object that determines when to invoke the extension. If it is None, extension.trigger is used instead. If it is None and the extension does not have the trigger attribute, the extension is triggered at every iteration by default. If the trigger is not callable, it is passed to IntervalTrigger to build an interval trigger.
priority (int) – Invocation priority of the extension. Extensions are invoked in the descending order of priorities in each iteration. If this is None, extension.priority is used instead.
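The fallback rules for name, trigger, and priority, together with the duplicate-name suffix, can be sketched as a standalone function. This is a toy model, not Chainer's extend(): toy_extend, Entry, ToyLog, and the iteration-only toy_interval are hypothetical, and numbering duplicates as _1, _2, ... as well as defaulting a missing priority to 0 are assumptions made for illustration.

```python
from collections import namedtuple
from types import SimpleNamespace

# Toy model of the registration rules above. toy_extend, Entry, and the
# iteration-only interval trigger are hypothetical, not Chainer's code;
# numbering duplicates _1, _2, ... and defaulting priority to 0 are assumptions.
Entry = namedtuple('Entry', 'extension trigger priority')


def every_iteration(trainer):
    return True


def toy_interval(period, unit):
    """Placeholder for IntervalTrigger; only the 'iteration' unit is modelled."""
    if unit != 'iteration':
        raise NotImplementedError(unit)
    return lambda trainer: trainer.iteration % period == 0


def toy_extend(registry, extension, name=None, trigger=None, priority=None):
    # Name: explicit argument, then extension.name, then extension.default_name.
    if name is None:
        name = getattr(extension, 'name', None)
    if name is None:
        name = getattr(extension, 'default_name', None)
    if name is None:
        raise TypeError('name is not given for the extension')
    # Trigger: explicit argument, then extension.trigger, then every iteration.
    if trigger is None:
        trigger = getattr(extension, 'trigger', None)
    if trigger is None:
        trigger = every_iteration
    if not callable(trigger):
        trigger = toy_interval(*trigger)  # e.g. (2, 'iteration')
    if priority is None:
        priority = getattr(extension, 'priority', 0)
    # A duplicated name gets an ordinal suffix.
    unique, ordinal = name, 0
    while unique in registry:
        ordinal += 1
        unique = '%s_%d' % (name, ordinal)
    registry[unique] = Entry(extension, trigger, priority)
    return unique


class ToyLog:
    default_name = 'log'
    trigger = (2, 'iteration')
    priority = 100

    def __call__(self, trainer):
        pass


registry = {}
first = toy_extend(registry, ToyLog())
second = toy_extend(registry, ToyLog())
custom = toy_extend(registry, lambda trainer: None, name='custom')
fires_at_4 = registry['log'].trigger(SimpleNamespace(iteration=4))
```

In this sketch the second ToyLog is stored as 'log_1', and the lambda, which has no trigger attribute, falls back to firing at every iteration.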
- get_extension(name)
Returns the extension of a given name.
- Parameters
name (str) – Name of the extension.
- Returns
Extension.
- run(show_loop_exception_msg=True)
Executes the training loop.
This method is the core of Trainer. It executes the whole loop of training the models.
Note that this method cannot be run more than once for the same trainer object.
Attributes
- elapsed_time
Total time used for the training.
The time is in seconds. If the training is resumed from a snapshot, it includes the time of all previous training runs that led to the current state of the trainer.
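The cumulative behavior can be pictured as "time saved in the snapshot plus time since this run started". This is a minimal sketch under that assumption: ToyClock and ToyTimer are hypothetical, and a fake clock is used so the numbers are deterministic.

```python
# Toy model of a cumulative elapsed-time counter. ToyClock and ToyTimer are
# hypothetical; time from earlier runs is stored with the other state and
# added back after resuming.


class ToyClock:
    """A fake clock so the example is deterministic."""

    def __init__(self):
        self.now = 0.0

    def __call__(self):
        return self.now


class ToyTimer:
    def __init__(self, clock):
        self.clock = clock
        self.previous = 0.0  # seconds accumulated by earlier runs
        self.start = None

    def begin(self):
        self.start = self.clock()

    @property
    def elapsed_time(self):
        return self.previous + (self.clock() - self.start)

    def state(self):
        return {'elapsed_time': self.elapsed_time}

    def load(self, state):
        self.previous = state['elapsed_time']


clock = ToyClock()
first = ToyTimer(clock)
first.begin()
clock.now = 120.0            # the first run trains for 120 s
saved = first.state()        # snapshot

clock.now = 1000.0           # some time later, training is resumed
second = ToyTimer(clock)
second.load(saved)
second.begin()
clock.now = 1030.0           # 30 s more of training
total = second.elapsed_time  # 150.0: 120 s before the snapshot + 30 s after
```

The idle gap between the snapshot and the resume (from 120 s to 1000 s on the fake clock) is not counted, because only time spent inside a run is accumulated.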