import bisect
import os
import os.path
import string
import sys

import six
import torch.utils.data as data
from PIL import Image

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle
class LSUNClass(data.Dataset):
    """Images of a single LSUN database stored in one LMDB environment.

    Args:
        root (string): Path of the LMDB database directory (``~`` is expanded).
        transform (callable, optional): Applied to each decoded RGB PIL image.
        target_transform (callable, optional): Applied to the target, which
            this dataset always produces as ``None``.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # Imported here rather than at module level, so lmdb is only needed
        # once a dataset is actually constructed.
        import lmdb
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform

        # Use the expanded path: passing the raw ``root`` meant paths such as
        # "~/lsun/bedroom_train_lmdb" were handed to lmdb unexpanded.
        self.env = lmdb.open(self.root, max_readers=1, readonly=True,
                             lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        # The key list is cached in a file in the current working directory
        # (named after the database path) to avoid a full cursor scan on
        # every construction.
        cache_file = '_cache_' + self.root.replace('/', '_')
        if os.path.isfile(cache_file):
            # NOTE(review): this unpickles a file from the working directory;
            # only safe when that directory is trusted.
            with open(cache_file, 'rb') as f:
                self.keys = pickle.load(f)
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            with open(cache_file, 'wb') as f:
                pickle.dump(self.keys, f)

    def __getitem__(self, index):
        """Return ``(img, target)`` for the ``index``-th key of the database.

        ``img`` is decoded with PIL and converted to RGB, then passed through
        ``transform`` if given. ``target`` is ``None``, passed through
        ``target_transform`` if given.
        """
        target = None
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        # A BytesIO constructed from the bytes is already positioned at 0.
        img = Image.open(six.BytesIO(imgbuf)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of entries reported by the LMDB environment's ``stat()``."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.root + ')'
class LSUN(data.Dataset):
    """
    `LSUN <http://lsun.cs.princeton.edu>`_ dataset.

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
            'train'/'val' load every category for that split; 'test' loads the
            single ``test_lmdb`` database.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: if ``classes`` is not a recognised split name or contains an
            unknown category or split postfix.
    """

    def __init__(self, root, classes='train',
                 transform=None, target_transform=None):
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower']
        dset_opts = ['train', 'val', 'test']
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform

        if isinstance(classes, str) and classes in dset_opts:
            if classes == 'test':
                # The test split is one database, not one per category.
                classes = [classes]
            else:
                classes = [c + '_' + classes for c in categories]
        elif isinstance(classes, list):
            for c in classes:
                # Names are "<category>_<split>"; categories themselves may
                # contain underscores, so split on the last one only.
                category, _, split = c.rpartition('_')
                if category not in categories:
                    raise ValueError('Unknown LSUN class: ' + category + '. '
                                     'Options are: ' + str(categories))
                if split not in dset_opts:
                    raise ValueError('Unknown postfix: ' + split + '. '
                                     'Options are: ' + str(dset_opts))
        else:
            raise ValueError('Unknown option for classes')
        self.classes = classes

        # One LSUNClass per database; an image's target is the position of
        # its database in ``self.classes``. target_transform is deliberately
        # not passed down: it is applied to that position in __getitem__.
        self.dbs = []
        for c in self.classes:
            self.dbs.append(LSUNClass(
                root=self.root + '/' + c + '_lmdb',
                transform=transform))

        # indices[i] is the total number of images in dbs[0..i], i.e. the
        # exclusive global index at which database i ends.
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)
        self.length = count

    def __getitem__(self, index):
        """
        Args:
            index (int): Index; negative values count from the end.

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError('LSUN index out of range')

        # The first database whose cumulative end exceeds ``index`` holds it.
        target = bisect.bisect_right(self.indices, index)
        sub = self.indices[target - 1] if target > 0 else 0
        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img, target

    def __len__(self):
        return self.length

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        fmt_str += '    Classes: {}\n'.format(self.classes)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str