Skip to content

Commit 0dc0048

Browse files
committed
data loader added
1 parent ef803c7 commit 0dc0048

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import torch
2+
import torchvision.transforms as transforms
3+
import torch.utils.data as data
4+
import os
5+
import pickle
6+
import numpy as np
7+
import nltk
8+
from PIL import Image
9+
from build_vocab import Vocabulary
10+
from pycocotools.coco import COCO
11+
12+
13+
class CocoDataset(data.Dataset):
    """COCO caption dataset compatible with torch.utils.data.DataLoader.

    The dataset is indexed by annotation id, so an image that is referenced
    by several annotations is returned once for each of them.
    """

    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.

        Args:
            root: image directory.
            json: coco annotation file path (passed to ``COCO``).
            vocab: vocabulary wrapper; called with a token to get its id.
            transform: optional callable applied to every loaded image.
        """
        self.root = root
        self.coco = COCO(json)
        # One dataset entry per annotation id, not per image.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return one data pair (image, caption) for the given index.

        Args:
            index: position in the list of annotation ids.

        Returns:
            image: the RGB image, passed through ``transform`` when one was
                given.
            target: 1D tensor built with ``torch.Tensor`` holding the id of
                '<start>', the ids of the caption tokens, then the id of
                '<end>'.
        """
        coco = self.coco
        vocab = self.vocab
        annotation = coco.anns[self.ids[index]]
        caption_text = annotation['caption']
        path = coco.loadImgs(annotation['image_id'])[0]['file_name']

        # convert() returns an independent image, so the file opened by
        # Image.open can be closed as soon as the conversion is done instead
        # of being left to the garbage collector.
        with Image.open(os.path.join(self.root, path)) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption_text).lower())
        caption = [vocab('<start>')]
        caption.extend(vocab(token) for token in tokens)
        caption.append(vocab('<end>'))
        # Kept as torch.Tensor(...) so the dtype callers receive is unchanged.
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        """Return the number of annotations (one sample per annotation)."""
        return len(self.ids)
54+
55+
56+
def collate_fn(data):
    """Create mini-batch tensors from a sequence of (image, caption) tuples.

    A custom collate_fn is needed because the default one cannot merge
    captions of different lengths (this requires padding).

    The input sequence is not modified; a sorted copy is used internally.

    Args:
        data: sequence of (image, caption) tuples.
            - image: torch tensor; all images in the batch must share one
              shape, e.g. (3, H, W), so they can be stacked.
            - caption: 1D torch tensor of variable length.

    Returns:
        images: torch tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, padded_length), where
            padded_length is the longest caption; shorter rows are
            right-padded with 0.
        lengths: list; valid length for each padded caption, in descending
            order.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Sort by caption length (descending order) on a copy, so the caller's
    # list keeps its order and any sequence (not only a list) is accepted.
    batch = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    if not batch:
        raise ValueError('collate_fn received an empty batch')
    images, captions = zip(*batch)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths), dtype=torch.long)
    for row, (cap, end) in enumerate(zip(captions, lengths)):
        # Slice assignment casts the caption values to the long dtype.
        targets[row, :end] = cap[:end]
    return images, targets, lengths
86+
87+
88+
def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Build a torch.utils.data.DataLoader for the custom COCO dataset.

    Args:
        root: image directory, forwarded to CocoDataset.
        json: coco annotation file path, forwarded to CocoDataset.
        vocab: vocabulary wrapper, forwarded to CocoDataset.
        transform: image transformer, forwarded to CocoDataset.
        batch_size: forwarded to the DataLoader.
        shuffle: forwarded to the DataLoader.
        num_workers: forwarded to the DataLoader.

    Returns:
        A DataLoader that batches samples with collate_fn, so each iteration
        yields (images, captions, lengths):
            - images: stacked image tensor; its per-image shape depends on
              ``transform``.
            - captions: tensor of shape (batch_size, padded_length).
            - lengths: list with the valid length of each caption; it has
              one entry per sample in the batch.
    """
    # COCO caption dataset.
    dataset = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)

    # Loader settings are passed through unchanged; batching is delegated to
    # the module's collate_fn.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)

0 commit comments

Comments
 (0)