Source code for torchtext.data.dataset

import io
import os
import zipfile
import tarfile
import gzip
import shutil
from functools import partial

import torch.utils.data

from .utils import RandomShuffler
from .example import Example
from ..utils import download_from_url, unicode_csv_reader


class Dataset(torch.utils.data.Dataset):
    """Defines a dataset composed of Examples along with its Fields.

    Attributes:
        sort_key (callable): A key to use for sorting dataset examples for
            batching together examples with similar lengths to minimize padding.
        examples (list(Example)): The examples in this dataset.
        fields (dict[str, Field]): Contains the name of each column or field,
            together with the corresponding Field object. Two fields with the
            same Field object will have a shared vocabulary.
    """
    sort_key = None

    def __init__(self, examples, fields, filter_pred=None):
        """Create a dataset from a list of Examples and Fields.

        Arguments:
            examples: List of Examples.
            fields (List(tuple(str, Field))): The Fields to use in this dataset.
                The string is a field name, and the Field is the associated field.
            filter_pred (callable or None): Use only examples for which
                filter_pred(example) is True, or use all examples if None.
                Default is None.
        """
        if filter_pred is not None:
            make_list = isinstance(examples, list)
            examples = filter(filter_pred, examples)
            if make_list:
                examples = list(examples)
        self.examples = examples
        self.fields = dict(fields)
        # Unpack field tuples
        for n, f in list(self.fields.items()):
            if isinstance(n, tuple):
                self.fields.update(zip(n, f))
                del self.fields[n]
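
    # Illustrative usage sketch (not from the original module): building a
    # Dataset directly from Examples. `TEXT` and `LABEL` are hypothetical
    # Field objects.
    #
    #     fields = [('text', TEXT), ('label', LABEL)]
    #     examples = [Example.fromlist(['hello world', 'pos'], fields)]
    #     ds = Dataset(examples, fields,
    #                  filter_pred=lambda ex: len(ex.text) > 0)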

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
               test=None, **kwargs):
        """Create Dataset objects for multiple splits of a dataset.

        Arguments:
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root (str): Root dataset storage directory. Default is '.data'.
            train (str): Suffix to add to path for the train set, or None for
                no train set. Default is None.
            validation (str): Suffix to add to path for the validation set, or
                None for no validation set. Default is None.
            test (str): Suffix to add to path for the test set, or None for no
                test set. Default is None.
            Remaining keyword arguments: Passed to the constructor of the
                Dataset (sub)class being used.

        Returns:
            Tuple[Dataset]: Datasets for train, validation, and
            test splits in that order, if provided.
        """
        if path is None:
            path = cls.download(root)
        train_data = None if train is None else cls(
            os.path.join(path, train), **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)
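
    # Illustrative sketch: a hypothetical subclass `MyDataset` whose
    # constructor accepts a file path can load all three splits in one call.
    # The file names below are placeholders.
    #
    #     train, valid, test = MyDataset.splits(
    #         path='data', train='train.tsv',
    #         validation='valid.tsv', test='test.tsv')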

    def split(self, split_ratio=0.7, stratified=False, strata_field='label',
              random_state=None):
        """Create train, test, and optionally validation splits from the
        instance's examples.

        Arguments:
            split_ratio (float or List of floats): a number in [0, 1] denoting
                the amount of data to be used for the training split (the rest
                is used for test), or a list of numbers denoting the relative
                sizes of train, test, and valid splits respectively. If the
                relative size for valid is missing, only the train-test split
                is returned. Default is 0.7 (for the train set).
            stratified (bool): whether the sampling should be stratified.
                Default is False.
            strata_field (str): name of the examples Field to stratify over.
                Default is 'label' for the conventional label field.
            random_state (tuple): the random seed used for shuffling.
                A return value of `random.getstate()`.

        Returns:
            Tuple[Dataset]: Datasets for train, validation, and
            test splits in that order, if the splits are provided.
        """
        train_ratio, test_ratio, val_ratio = check_split_ratio(split_ratio)

        # For the permutations
        rnd = RandomShuffler(random_state)
        if not stratified:
            train_data, test_data, val_data = rationed_split(
                self.examples, train_ratio, test_ratio, val_ratio, rnd)
        else:
            if strata_field not in self.fields:
                raise ValueError("Invalid field name for strata_field {}"
                                 .format(strata_field))
            strata = stratify(self.examples, strata_field)
            train_data, test_data, val_data = [], [], []
            for group in strata:
                # Stratify each group and add together the indices.
                group_train, group_test, group_val = rationed_split(
                    group, train_ratio, test_ratio, val_ratio, rnd)
                train_data += group_train
                test_data += group_test
                val_data += group_val

        splits = tuple(Dataset(d, self.fields)
                       for d in (train_data, val_data, test_data) if d)

        # In case the parent sort key isn't None
        if self.sort_key:
            for subset in splits:
                subset.sort_key = self.sort_key
        return splits
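
    # Illustrative sketch (`ds` is a hypothetical Dataset): an 80/10/10 split
    # stratified on the 'label' field. Note that the ratio list is ordered
    # train/test/valid, while the returned tuple is ordered train/valid/test.
    #
    #     train, valid, test = ds.split(split_ratio=[0.8, 0.1, 0.1],
    #                                   stratified=True)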

    def __getitem__(self, i):
        return self.examples[i]

    def __len__(self):
        try:
            return len(self.examples)
        except TypeError:
            # examples may be a generator with no len(); report an
            # effectively unbounded size
            return 2**32

    def __iter__(self):
        for x in self.examples:
            yield x

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)
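
    # Note: because of __getattr__ above, accessing an attribute named after
    # a field streams that column of the dataset, e.g. (`ds` is hypothetical):
    #
    #     lengths = [len(text) for text in ds.text]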

    @classmethod
    def download(cls, root, check=None):
        """Download and unzip an online archive (.zip, .gz, or .tgz).

        Arguments:
            root (str): Folder to download data to.
            check (str or None): Folder whose existence indicates that the
                dataset has already been downloaded, or None to check the
                existence of root/{cls.name}.

        Returns:
            str: Path to extracted dataset.
        """
        path = os.path.join(root, cls.name)
        check = path if check is None else check
        if not os.path.isdir(check):
            for url in cls.urls:
                if isinstance(url, tuple):
                    url, filename = url
                else:
                    filename = os.path.basename(url)
                zpath = os.path.join(path, filename)
                if not os.path.isfile(zpath):
                    if not os.path.exists(os.path.dirname(zpath)):
                        os.makedirs(os.path.dirname(zpath))
                    print('downloading {}'.format(filename))
                    download_from_url(url, zpath)
                zroot, ext = os.path.splitext(zpath)
                _, ext_inner = os.path.splitext(zroot)
                if ext == '.zip':
                    with zipfile.ZipFile(zpath, 'r') as zfile:
                        print('extracting')
                        zfile.extractall(path)
                # tarfile cannot handle bare .gz files
                elif ext == '.tgz' or ext == '.gz' and ext_inner == '.tar':
                    with tarfile.open(zpath, 'r:gz') as tar:
                        dirs = [member for member in tar.getmembers()]
                        tar.extractall(path=path, members=dirs)
                elif ext == '.gz':
                    with gzip.open(zpath, 'rb') as gz:
                        with open(zroot, 'wb') as uncompressed:
                            shutil.copyfileobj(gz, uncompressed)

        return os.path.join(path, cls.dirname)
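
    # Illustrative sketch: download() relies on class attributes that
    # subclasses are expected to define. A hypothetical subclass:
    #
    #     class MyCorpus(Dataset):
    #         name = 'my_corpus'    # storage folder under root
    #         dirname = ''          # folder created inside the archive
    #         urls = ['http://example.com/my_corpus.zip']  # placeholder URL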

    def filter_examples(self, field_names):
        """Remove unknown words from dataset examples with respect to the
        given fields.

        Arguments:
            field_names (list(str)): Only the parts of each example whose
                field name is in field_names will have their unknown words
                removed.
        """
        for i, example in enumerate(self.examples):
            for field_name in field_names:
                vocab = set(self.fields[field_name].vocab.stoi)
                text = getattr(example, field_name)
                example_part = [word for word in text if word in vocab]
                setattr(example, field_name, example_part)
            self.examples[i] = example
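
The sketch below is illustrative and not part of the module: it shows filter_examples in use, assuming a hypothetical dataset ds with a 'text' field whose TEXT Field has a built vocabulary.

# Build a capped vocabulary, then prune out-of-vocabulary tokens in place.
TEXT.build_vocab(ds, max_size=10000)
ds.filter_examples(['text'])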

class TabularDataset(Dataset):
    """Defines a Dataset of columns stored in CSV, TSV, or JSON format."""

    def __init__(self, path, format, fields, skip_header=False,
                 csv_reader_params={}, **kwargs):
        """Create a TabularDataset given a path, file format, and field list.

        Arguments:
            path (str): Path to the data file.
            format (str): The format of the data file. One of "CSV", "TSV", or
                "JSON" (case-insensitive).
            fields (list(tuple(str, Field)) or dict[str: tuple(str, Field)]):
                If using a list, the format must be CSV or TSV, and the values
                of the list should be tuples of (name, field). The fields
                should be in the same order as the columns in the CSV or TSV
                file, while tuples of (name, None) represent columns that will
                be ignored.

                If using a dict, the keys should be a subset of the JSON keys
                or CSV/TSV columns, and the values should be tuples of
                (name, field). Keys not present in the input dictionary are
                ignored. This allows the user to rename columns from their
                JSON/CSV/TSV key names and also enables selecting a subset of
                columns to load.
            skip_header (bool): Whether to skip the first line of the input file.
            csv_reader_params (dict): Parameters to pass to the csv reader.
                Only relevant when format is csv or tsv. See
                https://docs.python.org/3/library/csv.html#csv.reader
                for more details.
        """
        format = format.lower()
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format]

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f, **csv_reader_params)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
            else:
                reader = f

            if format in ['csv', 'tsv'] and isinstance(fields, dict):
                if skip_header:
                    raise ValueError('When using a dict to specify fields with a {} file, '
                                     'skip_header must be False and '
                                     'the file must have a header.'.format(format))
                header = next(reader)
                field_to_index = {f: header.index(f) for f in fields.keys()}
                make_example = partial(make_example, field_to_index=field_to_index)

            if skip_header:
                next(reader)

            examples = [make_example(line, fields) for line in reader]

        if isinstance(fields, dict):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(TabularDataset, self).__init__(examples, fields, **kwargs)
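
Two illustrative ways to construct a TabularDataset follow; the file paths, column names, and the TEXT/LABEL Fields are placeholders. A list of (name, field) tuples matches columns by position, while a dict keyed by header column names selects and renames columns.

# Headerless TSV: fields are matched to columns by position.
ds = TabularDataset(path='data/train.tsv', format='tsv',
                    fields=[('text', TEXT), ('label', LABEL)])

# CSV with a header: a dict selects and renames columns. skip_header must
# stay False here, because the header row is consumed to locate columns.
ds = TabularDataset(path='data/train.csv', format='csv',
                    fields={'sentence': ('text', TEXT),
                            'sentiment': ('label', LABEL)})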

def check_split_ratio(split_ratio):
    """Check that the split ratio argument is not malformed"""
    valid_ratio = 0.
    if isinstance(split_ratio, float):
        # Only the train set relative ratio is provided
        # Assert in bounds, validation size is zero
        assert 0. < split_ratio < 1., (
            "Split ratio {} not between 0 and 1".format(split_ratio))

        test_ratio = 1. - split_ratio
        return (split_ratio, test_ratio, valid_ratio)
    elif isinstance(split_ratio, list):
        # A list of relative ratios is provided
        length = len(split_ratio)
        assert length == 2 or length == 3, (
            "Length of split ratio list should be 2 or 3, got {}".format(split_ratio))

        # Normalize if necessary
        ratio_sum = sum(split_ratio)
        if not ratio_sum == 1.:
            split_ratio = [float(ratio) / ratio_sum for ratio in split_ratio]

        if length == 2:
            return tuple(split_ratio + [valid_ratio])
        return tuple(split_ratio)
    else:
        raise ValueError('Split ratio must be float or a list, got {}'
                         .format(type(split_ratio)))


def stratify(examples, strata_field):
    # The field has to be hashable otherwise this doesn't work
    # There are two iterations over the whole dataset here, which can be
    # reduced to just one if a dedicated method for stratified splitting is used
    unique_strata = set(getattr(example, strata_field) for example in examples)
    strata_maps = {s: [] for s in unique_strata}
    for example in examples:
        strata_maps[getattr(example, strata_field)].append(example)
    return list(strata_maps.values())


def rationed_split(examples, train_ratio, test_ratio, val_ratio, rnd):
    """Create a random permutation of examples, then split them by ratios.

    Arguments:
        examples: a list of data
        train_ratio, test_ratio, val_ratio: split fractions
        rnd: a random shuffler

    Examples:
        >>> examples = []
        >>> train_ratio, test_ratio, val_ratio = 0.7, 0.2, 0.1
        >>> rnd = torchtext.data.dataset.RandomShuffler(None)
        >>> train_examples, test_examples, valid_examples = \
                torchtext.data.dataset.rationed_split(examples, train_ratio,
                                                      test_ratio, val_ratio,
                                                      rnd)
    """
    N = len(examples)
    randperm = rnd(range(N))
    train_len = int(round(train_ratio * N))

    # Due to possible rounding problems
    if not val_ratio:
        test_len = N - train_len
    else:
        test_len = int(round(test_ratio * N))

    indices = (randperm[:train_len],  # Train
               randperm[train_len:train_len + test_len],  # Test
               randperm[train_len + test_len:])  # Validation

    # There's a possibly empty list for the validation set
    data = tuple([examples[i] for i in index] for index in indices)

    return data
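
As a quick illustration of the helper's normalization behavior (the values shown for the float case are approximate due to floating-point arithmetic):

check_split_ratio(0.7)         # -> (0.7, 0.3, 0.0): train, test, valid
check_split_ratio([7, 2, 1])   # -> (0.7, 0.2, 0.1): normalized to sum to 1
check_split_ratio([0.8, 0.2])  # -> (0.8, 0.2, 0.0): no validation split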
