# Source code for torchtext.data.batch
import torch
class Batch(object):
    """Defines a batch of examples along with its Fields.

    Attributes:
        batch_size: Number of examples in the batch.
        dataset: A reference to the dataset object the examples come from
            (which itself contains the dataset's Field objects).
        fields: The names of the dataset's fields, stored as
            ``dataset.fields.keys()`` (a live view, not a copy).
        input_fields: The names of the fields that are used as input for the
            model (non-``None`` fields whose ``is_target`` is false).
        target_fields: The names of the fields that are used as targets
            during model training (non-``None`` fields whose ``is_target``
            is true).

    The ``train`` argument of :meth:`fromvars` is deprecated and ignored; it
    is kept only for backwards compatibility (UNUSED as of the merger with
    pytorch 0.4).

    Also stores the processed value of each non-``None`` field as an
    attribute named after that field.
    """

    @staticmethod
    def _split_fields(dataset):
        """Return ``(input_names, target_names)`` derived from ``dataset.fields``.

        Fields whose value is ``None`` appear in neither list.
        """
        inputs = [k for k, v in dataset.fields.items()
                  if v is not None and not v.is_target]
        targets = [k for k, v in dataset.fields.items()
                   if v is not None and v.is_target]
        return inputs, targets

    def __init__(self, data=None, dataset=None, device=None):
        """Create a Batch from a list of examples.

        Args:
            data: List of examples. Each example must expose one attribute per
                non-``None`` entry of ``dataset.fields``. When ``None``, an
                empty Batch (no attributes set) is created; :meth:`fromvars`
                relies on this.
            dataset: Dataset the examples come from; its ``fields`` mapping
                decides which attributes are processed.
            device: Forwarded unchanged to every ``Field.process`` call.
        """
        if data is None:
            return
        self.batch_size = len(data)
        self.dataset = dataset
        # NOTE: ``keys()`` is a live view of ``dataset.fields``, not a copy;
        # the type is kept as-is for backwards compatibility.
        self.fields = dataset.fields.keys()
        self.input_fields, self.target_fields = self._split_fields(dataset)
        for name, field in dataset.fields.items():
            if field is not None:
                batch = [getattr(x, name) for x in data]
                setattr(self, name, field.process(batch, device=device))

    @classmethod
    def fromvars(cls, dataset, batch_size, train=None, **kwargs):
        """Create a Batch directly from a number of Variables.

        Args:
            dataset: Dataset whose ``fields`` define the batch's field names
                and the input/target split used when iterating the batch.
            batch_size: Value stored as ``batch_size`` (returned by ``len``).
            train: Deprecated and ignored; kept for backwards compatibility.
            **kwargs: Attribute name/value pairs set on the batch, typically
                one per field.

        Returns:
            A new instance of ``cls``.
        """
        batch = cls()
        batch.batch_size = batch_size
        batch.dataset = dataset
        batch.fields = dataset.fields.keys()
        # Without these, ``inputs, targets = batch`` raised AttributeError for
        # batches built here. Set before kwargs so explicit values still win.
        batch.input_fields, batch.target_fields = cls._split_fields(dataset)
        for k, v in kwargs.items():
            setattr(batch, k, v)
        return batch

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Return a multi-line summary: type, size, dataset name and fields."""
        if not self.__dict__:
            return 'Empty {} instance'.format(torch.typename(self))
        # Only fields actually set on this batch are listed (None-valued
        # fields are never set by ``__init__``).
        fields_to_index = filter(lambda field: field is not None, self.fields)
        var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name))
                              for name in fields_to_index if hasattr(self, name)])
        data_str = (' from {}'.format(self.dataset.name.upper())
                    if hasattr(self.dataset, 'name')
                    and isinstance(self.dataset.name, str) else '')
        strt = '[{} of size {}{}]\n{}'.format(torch.typename(self),
                                             self.batch_size, data_str, var_strs)
        return '\n' + strt

    def __len__(self):
        return self.batch_size

    def _get_field_values(self, fields):
        """Return ``None`` for no fields, the bare value for one, else a tuple."""
        if not fields:
            return None
        elif len(fields) == 1:
            return getattr(self, fields[0])
        else:
            return tuple(getattr(self, f) for f in fields)

    def __iter__(self):
        """Yield the input values, then the target values.

        This supports ``inputs, targets = batch``.
        """
        yield self._get_field_values(self.input_fields)
        yield self._get_field_values(self.target_fields)
def _short_str(tensor):
# unwrap variable to tensor
if not torch.is_tensor(tensor):
# (1) unpack variable
if hasattr(tensor, 'data'):
tensor = getattr(tensor, 'data')
# (2) handle include_lengths
elif isinstance(tensor, tuple):
return str(tuple(_short_str(t) for t in tensor))
# (3) fallback to default str
else:
return str(tensor)
# copied from torch _tensor_str
size_str = 'x'.join(str(size) for size in tensor.size())
device_str = '' if not tensor.is_cuda else \
' (GPU {})'.format(tensor.get_device())
strt = '[{} of size {}{}]'.format(torch.typename(tensor),
size_str, device_str)
return strt