chainer.training.extensions.snapshot

chainer.training.extensions.snapshot(savefun=None, filename='snapshot_iter_{.updater.iteration}', *, target=None, condition=None, writer=None, snapshot_on_error=False)
Returns a trainer extension to take snapshots of the trainer.
This extension serializes the trainer object and saves it to the output directory. It is used to support resuming the training loop from the saved state.
This extension is called once per epoch by default. To take a snapshot at a different interval, a trigger object specifying the required interval can be passed along with this extension to the extend() method of the trainer.
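An interval trigger such as `(1000, 'iteration')` fires whenever the iteration count reaches a multiple of the period. The sketch below models only that behaviour, using a hypothetical `FakeUpdater` and `every_n_iterations` helper. Neither is part of Chainer. In practice you would write `trainer.extend(extensions.snapshot(), trigger=(1000, 'iteration'))` and let Chainer build the trigger.

```python
from dataclasses import dataclass


@dataclass
class FakeUpdater:
    """Hypothetical stand-in for a trainer's updater; only tracks the iteration count."""
    iteration: int = 0


def every_n_iterations(period):
    """Hypothetical trigger: fires when the iteration is a positive multiple of period."""
    def trigger(updater):
        return updater.iteration > 0 and updater.iteration % period == 0
    return trigger


updater = FakeUpdater()
fire = every_n_iterations(1000)
fired_at = []
for _ in range(3500):
    updater.iteration += 1  # one training step
    if fire(updater):
        fired_at.append(updater.iteration)  # a snapshot would be taken here

print(fired_at)  # [1000, 2000, 3000]
```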
The default priority is -100, which is lower than that of most built-in extensions.
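This sketch assumes that, as in Chainer's Trainer, extensions due at the same iteration run in order of decreasing priority. Under that assumption, a priority of -100 means the snapshot is written after higher-priority extensions have updated their state. Only the -100 value comes from this page. The other extension names and priorities are hypothetical, chosen to show the ordering.

```python
# Hypothetical extensions and priorities; only snapshot's -100 is documented here.
priorities = {
    "snapshot": -100,
    "metrics_writer": 300,
    "metrics_reader": 100,
    "metrics_editor": 200,
}

# Run order: highest priority first, so the snapshot comes last.
order = sorted(priorities, key=priorities.get, reverse=True)
print(order)  # ['metrics_writer', 'metrics_editor', 'metrics_reader', 'snapshot']
```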
Note
This extension first writes the serialized object to a temporary file and then renames it to the target file name. Thus, if the program stops right before the renaming, the temporary file might be left in the output directory.
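The write-then-rename pattern from the note can be sketched in plain Python as follows. This is a minimal illustration of the technique, not Chainer's implementation. `atomic_save` and `save_text` are hypothetical helpers. `os.replace` is used because it overwrites an existing file, and on POSIX systems it does so atomically when both paths are on the same file system.

```python
import os
import tempfile


def atomic_save(savefun, target, outdir, filename):
    """Hypothetical helper: save to a temp file in outdir, then rename it into place."""
    # Create the temp file in the destination directory so the rename stays
    # on one file system.
    fd, tmp_path = tempfile.mkstemp(prefix="tmp" + filename, dir=outdir)
    os.close(fd)  # savefun reopens the file by path
    try:
        savefun(tmp_path, target)
        os.replace(tmp_path, os.path.join(outdir, filename))
    except BaseException:
        # Clean up after an exception. A hard kill before the rename would
        # still leave the temp file behind, as the note warns.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def save_text(path, obj):
    """Toy savefun with the (path, obj) signature; writes str(obj)."""
    with open(path, "w") as f:
        f.write(str(obj))


with tempfile.TemporaryDirectory() as outdir:
    atomic_save(save_text, {"iteration": 1000}, outdir, "snapshot_iter_1000")
    listing = sorted(os.listdir(outdir))
    with open(os.path.join(outdir, "snapshot_iter_1000")) as f:
        content = f.read()

print(listing)  # ['snapshot_iter_1000']
```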
- Parameters
savefun – Function to save the trainer. It takes two arguments: the output file path and the trainer object. It is chainer.serializers.save_npz() by default. If writer is specified, this argument must be None.
filename (str) – Name of the file into which the trainer is serialized. It can be a format string, where the trainer object is passed to the str.format() method.
target – Object to serialize. If it is not specified, it will be the trainer object.
condition – Condition object. It must be a callable object that takes no arguments and returns a boolean. If it returns True, the snapshot is taken; otherwise, it is skipped. The default is a function that always returns True.
writer – Writer object. It must be a callable object. See below for the list of built-in writers. If savefun is other than None, this argument must be None. In that case, a SimpleWriter object instantiated with the specified savefun argument will be used.
snapshot_on_error (bool) – Whether to take a snapshot if the training loop fails.
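Because `filename` is passed through `str.format()` with the trainer as the only positional argument, a field such as `{.updater.iteration}` reads an attribute path from the trainer. The sketch below reproduces that formatting with a hypothetical stand-in object built from `types.SimpleNamespace`. It models only the attributes the format strings touch.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a trainer: only the attributes used below are modelled.
trainer = SimpleNamespace(updater=SimpleNamespace(iteration=12000, epoch=3))

# The default pattern: "{.updater.iteration}" resolves to trainer.updater.iteration.
default_name = "snapshot_iter_{.updater.iteration}".format(trainer)

# With two fields, use an explicit index 0. Bare "{.x}" fields are auto-numbered
# 0, 1, ..., so the second one would look for a missing positional argument.
custom_name = "snapshot_epoch_{0.updater.epoch}_iter_{0.updater.iteration}".format(trainer)

print(default_name)  # snapshot_iter_12000
print(custom_name)   # snapshot_epoch_3_iter_12000
```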
- Returns
Snapshot extension object.
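The `condition` parameter accepts any zero-argument callable that returns a boolean. As an illustration, the hypothetical factory below builds a condition that allows a snapshot only while the file system holding a given path has at least a minimum number of free bytes, using `shutil.disk_usage`. Whether such a check suits a given setup is a deployment choice, not something this page prescribes.

```python
import shutil
import tempfile


def make_free_space_condition(path, min_free_bytes):
    """Hypothetical factory: returns a zero-argument callable usable as `condition`."""
    def condition():
        # True -> take the snapshot; False -> skip it this time.
        return shutil.disk_usage(path).free >= min_free_bytes
    return condition


outdir = tempfile.gettempdir()
always_ok = make_free_space_condition(outdir, 0)
never_ok = make_free_space_condition(outdir, float("inf"))
print(always_ok(), never_ok())  # True False
```

Such a callable could then be passed as `extensions.snapshot(condition=make_free_space_condition('result', 10 * 1024**3))`, where `'result'` stands in for the trainer's output directory.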
Using asynchronous writers
By specifying the writer argument, writing operations can be made asynchronous, hiding the I/O overhead of snapshots.
>>> from chainer.training import extensions
>>> writer = extensions.snapshot_writers.ProcessWriter()
>>> trainer.extend(extensions.snapshot(writer=writer), trigger=(1, 'epoch'))
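The following is a minimal thread-based sketch, loosely in the spirit of ThreadWriter, that shows how an asynchronous writer hides I/O latency. It is not Chainer's implementation. The class `SketchThreadWriter` and the toy `save_json` are hypothetical. The calling convention `writer(filename, outdir, target)` and the `finalize()` method are assumptions modelled on Chainer's writer interface and should be checked against the API reference.

```python
import json
import os
import tempfile
import threading


def save_json(path, obj):
    """Toy savefun with the (path, obj) signature."""
    with open(path, "w") as f:
        json.dump(obj, f)


class SketchThreadWriter:
    """Hypothetical writer: each call saves on a new background thread."""

    def __init__(self, savefun=save_json):
        self._savefun = savefun
        self._threads = []

    def __call__(self, filename, outdir, target):
        # Return immediately; the actual file I/O happens on a worker thread.
        path = os.path.join(outdir, filename)
        thread = threading.Thread(target=self._savefun, args=(path, target))
        thread.start()
        self._threads.append(thread)

    def finalize(self):
        # Block until every pending write has finished.
        for thread in self._threads:
            thread.join()
        self._threads.clear()


with tempfile.TemporaryDirectory() as outdir:
    writer = SketchThreadWriter()
    for it in (1000, 2000):
        # A fresh dict per call, so no worker ever sees a mutated object.
        writer("snapshot_iter_{}".format(it), outdir, {"iteration": it})
    writer.finalize()
    names = sorted(os.listdir(outdir))

print(names)  # ['snapshot_iter_1000', 'snapshot_iter_2000']
```

The catch in this design is that the saved object must stay unchanged until its write finishes. That is why the example passes a fresh dictionary on each call.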
To change the format, such as npz or hdf5, pass a saving function as the savefun argument of the writer.
>>> from chainer.training import extensions
>>> from chainer import serializers
>>> writer = extensions.snapshot_writers.ProcessWriter(
...     savefun=serializers.save_npz)
>>> trainer.extend(extensions.snapshot(writer=writer), trigger=(1, 'epoch'))
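A saving function only needs to accept `(path, obj)`. When a saver needs extra options, one approach is to bind them with `functools.partial`. For instance, `functools.partial(serializers.save_npz, compression=False)` would give an uncompressed npz saver with the same two-argument shape. The self-contained sketch below shows the same binding with a hypothetical `save_text` function in place of a Chainer serializer.

```python
import functools
import os
import tempfile


def save_text(path, obj, *, header=""):
    """Hypothetical saver: writes an optional header line, then str(obj), to path."""
    with open(path, "w") as f:
        if header:
            f.write(header + "\n")
        f.write(str(obj))


# Bind the extra option so the result takes exactly (path, obj), like a savefun.
savefun = functools.partial(save_text, header="# snapshot")

with tempfile.TemporaryDirectory() as outdir:
    path = os.path.join(outdir, "snapshot_iter_1000")
    savefun(path, {"iteration": 1000})
    with open(path) as f:
        text = f.read()

print(text)
```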
This is the list of built-in snapshot writers.
chainer.training.extensions.snapshot_writers.SimpleWriter
chainer.training.extensions.snapshot_writers.ThreadWriter
chainer.training.extensions.snapshot_writers.ProcessWriter
chainer.training.extensions.snapshot_writers.ThreadQueueWriter
chainer.training.extensions.snapshot_writers.ProcessQueueWriter