
Source code for torchtext.datasets.sst

import os

from .. import data


class SST(data.Dataset):

    urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip']
    dirname = 'trees'
    name = 'sst'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    def __init__(self, path, text_field, label_field, subtrees=False,
                 fine_grained=False, **kwargs):
        """Create an SST dataset instance given a path and fields.

        Arguments:
            path: Path to the data file
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            subtrees: Whether to include sentiment-tagged subphrases in
                addition to complete examples. Default: False.
            fine_grained: Whether to use 5-class instead of 3-class
                labeling. Default: False.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        fields = [('text', text_field), ('label', label_field)]

        def get_label_str(label):
            pre = 'very ' if fine_grained else ''
            return {'0': pre + 'negative', '1': 'negative', '2': 'neutral',
                    '3': 'positive', '4': pre + 'positive',
                    None: None}[label]
        label_field.preprocessing = data.Pipeline(get_label_str)

        with open(os.path.expanduser(path)) as f:
            if subtrees:
                examples = [ex for line in f for ex in
                            data.Example.fromtree(line, fields, True)]
            else:
                examples = [data.Example.fromtree(line, fields)
                            for line in f]
        super(SST, self).__init__(examples, fields, **kwargs)
    @classmethod
    def splits(cls, text_field, label_field, root='.data',
               train='train.txt', validation='dev.txt', test='test.txt',
               train_subtrees=False, **kwargs):
        """Create dataset objects for splits of the SST dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose trees
                subdirectory the data files will be stored.
            train: The filename of the train data. Default: 'train.txt'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'dev.txt'.
            test: The filename of the test data, or None to not load the
                test set. Default: 'test.txt'.
            train_subtrees: Whether to use all subtrees in the training set.
                Default: False.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), text_field, label_field,
            subtrees=train_subtrees, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), text_field, label_field,
            **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), text_field, label_field, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)
    @classmethod
    def iters(cls, batch_size=32, device=0, root='.data',
              vectors=None, **kwargs):
        """Create iterator objects for splits of the SST dataset.

        Arguments:
            batch_size: Batch size. Default: 32.
            device: Device to create batches on. Use -1 for CPU and None
                for the currently active GPU device.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose trees
                subdirectory the data files will be stored.
            vectors: one of the available pretrained vectors, or a list
                whose elements are each one of the available pretrained
                vectors (see Vocab.load_vectors)
            Remaining keyword arguments: Passed to the splits method.
        """
        TEXT = data.Field()
        LABEL = data.Field(sequential=False)

        train, val, test = cls.splits(TEXT, LABEL, root=root, **kwargs)

        TEXT.build_vocab(train, vectors=vectors)
        LABEL.build_vocab(train)

        return data.BucketIterator.splits(
            (train, val, test), batch_size=batch_size, device=device)
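
The constructor maps raw tree labels to strings through get_label_str: with fine_grained=False, '0' and '1' both become 'negative' and '3' and '4' both become 'positive'; with fine_grained=True, '0' becomes 'very negative' and '4' becomes 'very positive'. The following is a minimal direct-construction sketch; the path is hypothetical and assumes the PTB-style trees archive has already been downloaded and extracted, and the TEXT/LABEL names are illustrative, not part of this module.

from torchtext import data
from torchtext.datasets import SST

TEXT = data.Field()
LABEL = data.Field(sequential=False)

# subtrees=True also yields an example per sentiment-tagged subphrase.
train = SST('.data/sst/trees/train.txt', TEXT, LABEL,
            subtrees=True, fine_grained=True)
print(train[0].text, train[0].label)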
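
More typically, splits handles the download and produces all three datasets at once; extra keyword arguments such as fine_grained are forwarded to the constructor. The sketch below shows one end-to-end setup with the Field/BucketIterator API used by this module; the GloVe vector name and batch size are illustrative assumptions.

from torchtext import data
from torchtext import datasets

TEXT = data.Field(lower=True)          # tokenized sentence text
LABEL = data.Field(sequential=False)   # one label per example

# Downloads and extracts the archive under .data/sst/trees on first use.
train, val, test = datasets.SST.splits(
    TEXT, LABEL, fine_grained=True, train_subtrees=False)

# Vocabularies are built from the training split only.
TEXT.build_vocab(train, vectors='glove.6B.100d')
LABEL.build_vocab(train)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test), batch_size=64, device=-1)  # -1 keeps batches on CPU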
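
When the default fields suffice, iters collapses the whole setup into a single call: it builds the fields and vocabularies internally and forwards remaining keyword arguments to splits. The vector name here is again an illustrative assumption.

from torchtext import datasets

train_iter, val_iter, test_iter = datasets.SST.iters(
    batch_size=32, device=-1, vectors='glove.6B.300d')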
