A mixin for putting Python state in an object-based checkpoint.
This is an abstract class which allows extensions to TensorFlow's object-based
checkpointing (see tf.train.Checkpoint). For example, a wrapper for NumPy
arrays:
import io
import numpy
import tensorflow as tf

class NumpyWrapper(tf.train.experimental.PythonState):

  def __init__(self, array):
    self.array = array

  def serialize(self):
    # Write the array to an in-memory buffer and return its contents.
    string_file = io.BytesIO()
    try:
      numpy.save(string_file, self.array, allow_pickle=False)
      serialized = string_file.getvalue()
    finally:
      string_file.close()
    return serialized

  def deserialize(self, string_value):
    # Rebuild the array from the value produced by serialize().
    string_file = io.BytesIO(string_value)
    try:
      self.array = numpy.load(string_file, allow_pickle=False)
    finally:
      string_file.close()
Instances of NumpyWrapper are checkpointable objects, and will be saved to and
restored from checkpoints along with TensorFlow state like variables.
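# `prefix` is the checkpoint path prefix; it is defined elsewhere and not shown here.
root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
save_path = root.save(prefix)
# Modify the wrapped array after saving.
root.numpy.array *= 2.
assert [2.] == root.numpy.array
# Restoring brings back the value that was saved.
root.restore(save_path)
assert [1.] == root.numpy.array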
deserialize
deserialize(string_value)
Callback to deserialize the object.
serialize
serialize()
Callback to serialize the object. Returns a string; in the NumpyWrapper example
above, this is the bytes value produced by io.BytesIO.getvalue().
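The two callbacks form a round trip: whatever serialize returns must be accepted by deserialize. The sketch below illustrates that contract with a hypothetical JsonState class holding a JSON-serializable dict. It is an assumption-laden illustration, not part of the TensorFlow API: the class name and its value attribute are invented here, and the tf.train.experimental.PythonState base class is deliberately left out so the snippet runs without TensorFlow. The callbacks are called by hand to mimic what a checkpoint save and restore would trigger.

```python
import json


class JsonState:
    """Hypothetical example: holds a JSON-serializable dict.

    In real use this class would inherit from
    tf.train.experimental.PythonState; the base class is omitted here
    so the sketch runs without TensorFlow installed.
    """

    def __init__(self, value):
        self.value = value

    def serialize(self):
        # Encode to bytes, matching the bytes returned by the NumPy example.
        return json.dumps(self.value, sort_keys=True).encode("utf-8")

    def deserialize(self, string_value):
        # Replace the current state in place with the decoded value.
        self.value = json.loads(string_value.decode("utf-8"))


state = JsonState({"step": 1, "name": "run-a"})
saved = state.serialize()    # stands in for what a checkpoint save captures
state.value["step"] = 2      # the Python state changes after the save
state.deserialize(saved)     # stands in for what a checkpoint restore calls
assert state.value == {"step": 1, "name": "run-a"}
```

Note that deserialize updates the existing object rather than returning a new one, which mirrors how NumpyWrapper assigns to self.array.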