chainer.datasets.TransformDataset
class chainer.datasets.TransformDataset(dataset, transform)
Dataset that indexes the base dataset and transforms the data.
This dataset wraps the base dataset by modifying the behavior of the base dataset’s __getitem__(). Arrays returned by the base dataset’s __getitem__() when it is called with an integer argument are transformed by the given function transform. Also, __len__() returns the integer returned by the base dataset’s __len__().

The function transform takes, as an argument, in_data, which is the output of the base dataset’s __getitem__(), and returns the transformed arrays as output. See the following example. Since in_data directly refers to the item in the dataset, take care that transform does not modify it. For example, the line img = img - 0.5 below is correct because it creates a new array instead of modifying img. However, it would be incorrect to use img -= 0.5, since that would update the contents of the item in the dataset in place, corrupting it.

>>> from chainer.datasets import get_mnist
>>> from chainer.datasets import TransformDataset
>>> dataset, _ = get_mnist()
>>> def transform(in_data):
...     img, label = in_data
...     img = img - 0.5  # scale to [-0.5, 0.5]
...     return img, label
>>> dataset = TransformDataset(dataset, transform)
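The difference between the two forms can be checked without Chainer. The sketch below assumes only NumPy and uses a hypothetical base dataset (a plain list of (image, label) pairs standing in for the MNIST data). It shows that img = img - 0.5 leaves the stored item intact, while img -= 0.5 writes into it.

```python
import numpy as np

# Hypothetical stand-in for a base dataset: a list of (image, label) pairs.
base = [(np.ones(4, dtype=np.float32), 7)]
original = base[0][0].copy()  # snapshot used for comparison below

def safe_transform(in_data):
    img, label = in_data
    img = img - 0.5  # binds img to a new array; the stored item is untouched
    return img, label

def unsafe_transform(in_data):
    img, label = in_data
    img -= 0.5  # writes into the very array stored in base[0]
    return img, label

safe_img, _ = safe_transform(base[0])
print(np.array_equal(base[0][0], original))  # True: item intact

unsafe_img, _ = unsafe_transform(base[0])
print(np.array_equal(base[0][0], original))  # False: item corrupted
print(base[0][0])  # [0.5 0.5 0.5 0.5]
```

Each further call to unsafe_transform would subtract another 0.5 from the stored item, so the values seen on later accesses (for example, in later epochs) would keep drifting.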
- Parameters
dataset – The underlying dataset. The index of this dataset corresponds to the index of the base dataset. This object needs to support the methods __getitem__() and __len__() as described above.
transform (callable) – A function that is called to transform values returned by the underlying dataset’s __getitem__().
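A minimal sketch of the wrapping behavior described above, in plain Python. The class name MiniTransformDataset is hypothetical and this is not Chainer’s implementation; it only illustrates how an integer lookup passes the base item to transform, and how __len__() is delegated unchanged.

```python
class MiniTransformDataset:
    """Hypothetical stand-in illustrating TransformDataset-style wrapping."""

    def __init__(self, dataset, transform):
        # dataset must support __getitem__ and __len__; transform is any callable.
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        # The length is taken directly from the base dataset.
        return len(self._dataset)

    def __getitem__(self, i):
        # Integer indices only; slice and list indexing are omitted here.
        in_data = self._dataset[i]      # raw item from the base dataset
        return self._transform(in_data)  # transformed result


base = [(1, 'a'), (2, 'b'), (3, 'c')]
wrapped = MiniTransformDataset(base, lambda in_data: (in_data[0] * 10, in_data[1]))
print(len(wrapped))  # 3
print(wrapped[1])    # (20, 'b')
print(base[1])       # (2, 'b'): the base item is unchanged
```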
Methods
__getitem__(index)
Returns an example or a sequence of examples.
It implements standard Python indexing and one-dimensional integer array indexing. By default it uses the get_example() method, but implementations may override it, for example, to improve slicing performance.
- Parameters
index (int, slice, list or numpy.ndarray) – An index of an example or indexes of examples.
- Returns
If index is an int, returns an example created by get_example. If index is a slice, a one-dimensional list, or a one-dimensional numpy.ndarray, returns a list of examples created by get_example.
Example
>>> import numpy
>>> from chainer import dataset
>>> class SimpleDataset(dataset.DatasetMixin):
...     def __init__(self, values):
...         self.values = values
...     def __len__(self):
...         return len(self.values)
...     def get_example(self, i):
...         return self.values[i]
...
>>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
>>> ds[1]  # Access by int
1
>>> ds[1:3]  # Access by slice
[1, 2]
>>> ds[[4, 0]]  # Access by one-dimensional integer list
[4, 0]
>>> index = numpy.arange(3)
>>> ds[index]  # Access by one-dimensional integer numpy.ndarray
[0, 1, 2]
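The indexing rules listed under Returns amount to a small dispatch on the type of index. The sketch below is a hypothetical re-implementation (MiniDatasetMixin is not a Chainer class, and Chainer’s own code may differ) showing how an int, a slice, a list, or a one-dimensional numpy.ndarray each end up as calls to get_example().

```python
import numpy as np


class MiniDatasetMixin:
    """Hypothetical base class sketching DatasetMixin-style indexing."""

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Resolve start/stop/step (including negatives) against len(self).
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        if isinstance(index, (list, np.ndarray)):
            # One-dimensional integer array indexing; NumPy ints become Python ints.
            return [self.get_example(int(i)) for i in index]
        # Plain integer index.
        return self.get_example(index)

    def get_example(self, i):
        raise NotImplementedError


class SimpleDataset(MiniDatasetMixin):
    def __init__(self, values):
        self.values = values

    def __len__(self):
        return len(self.values)

    def get_example(self, i):
        return self.values[i]


ds = SimpleDataset([0, 1, 2, 3, 4, 5])
print(ds[1])             # 1
print(ds[1:3])           # [1, 2]
print(ds[[4, 0]])        # [4, 0]
print(ds[np.arange(3)])  # [0, 1, 2]
```

Because every path goes through get_example(), a subclass only has to define that method and __len__() to get all four indexing forms; overriding __getitem__() is needed only when batch access can be done faster than one example at a time.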
get_example(i)
Returns the i-th example.
Implementations should override it. It should raise IndexError if the index is invalid.
- Parameters
i (int) – The index of the example.
- Returns
The i-th example.
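A sketch of a get_example() override that honors the IndexError contract. The class RangeSquares is hypothetical: it computes examples on demand instead of indexing a stored container, so it must check bounds itself (negative indices are simply rejected in this sketch).

```python
class RangeSquares:
    """Hypothetical dataset whose i-th example is i * i, computed lazily."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def get_example(self, i):
        # No container lookup happens here, so validate explicitly and
        # raise IndexError for anything outside [0, n).
        if not 0 <= i < self.n:
            raise IndexError('index {} out of range for length {}'.format(i, self.n))
        return i * i

    def __getitem__(self, i):
        # Integer-only delegation, enough for Python's fallback iteration.
        return self.get_example(i)


ds = RangeSquares(4)
print(ds.get_example(3))  # 9
try:
    ds.get_example(4)
except IndexError as err:
    message = str(err)
print(message)   # index 4 out of range for length 4
print(list(ds))  # [0, 1, 4, 9]
```

Raising IndexError specifically (rather than, say, ValueError) also matters in plain Python: an object that defines __getitem__() but not __iter__() is iterated by calling __getitem__() with 0, 1, 2, and so on until IndexError is raised, which is why list(ds) above terminates.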