1+ import torch
2+ import torchvision .transforms as transforms
3+ import torch .utils .data as data
4+ import os
5+ import pickle
6+ import numpy as np
7+ import nltk
8+ from PIL import Image
9+ from build_vocab import Vocabulary
10+ from pycocotools .coco import COCO
11+
12+
13+ class CocoDataset (data .Dataset ):
14+ """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
15+ def __init__ (self , root , json , vocab , transform = None ):
16+ """Set the path for images, captions and vocabulary wrapper.
17+
18+ Args:
19+ root: image directory.
20+ json: coco annotation file path.
21+ vocab: vocabulary wrapper.
22+ transform: image transformer.
23+ """
24+ self .root = root
25+ self .coco = COCO (json )
26+ self .ids = list (self .coco .anns .keys ())
27+ self .vocab = vocab
28+ self .transform = transform
29+
30+ def __getitem__ (self , index ):
31+ """Returns one data pair (image and caption)."""
32+ coco = self .coco
33+ vocab = self .vocab
34+ ann_id = self .ids [index ]
35+ caption = coco .anns [ann_id ]['caption' ]
36+ img_id = coco .anns [ann_id ]['image_id' ]
37+ path = coco .loadImgs (img_id )[0 ]['file_name' ]
38+
39+ image = Image .open (os .path .join (self .root , path )).convert ('RGB' )
40+ if self .transform is not None :
41+ image = self .transform (image )
42+
43+ # Convert caption (string) to word ids.
44+ tokens = nltk .tokenize .word_tokenize (str (caption ).lower ())
45+ caption = []
46+ caption .append (vocab ('<start>' ))
47+ caption .extend ([vocab (token ) for token in tokens ])
48+ caption .append (vocab ('<end>' ))
49+ target = torch .Tensor (caption )
50+ return image , target
51+
52+ def __len__ (self ):
53+ return len (self .ids )
54+
55+
56+ def collate_fn (data ):
57+ """Creates mini-batch tensors from the list of tuples (image, caption).
58+
59+ We should build custom collate_fn rather than using default collate_fn,
60+ because merging caption (including padding) is not supported in default.
61+
62+ Args:
63+ data: list of tuple (image, caption).
64+ - image: torch tensor of shape (3, 256, 256).
65+ - caption: torch tensor of shape (?); variable length.
66+
67+ Returns:
68+ images: torch tensor of shape (batch_size, 3, 256, 256).
69+ targets: torch tensor of shape (batch_size, padded_length).
70+ lengths: list; valid length for each padded caption.
71+ """
72+ # Sort a data list by caption length (descending order).
73+ data .sort (key = lambda x : len (x [1 ]), reverse = True )
74+ images , captions = zip (* data )
75+
76+ # Merge images (from tuple of 3D tensor to 4D tensor).
77+ images = torch .stack (images , 0 )
78+
79+ # Merge captions (from tuple of 1D tensor to 2D tensor).
80+ lengths = [len (cap ) for cap in captions ]
81+ targets = torch .zeros (len (captions ), max (lengths )).long ()
82+ for i , cap in enumerate (captions ):
83+ end = lengths [i ]
84+ targets [i , :end ] = cap [:end ]
85+ return images , targets , lengths
86+
87+
88+ def get_loader (root , json , vocab , transform , batch_size , shuffle , num_workers ):
89+ """Returns torch.utils.data.DataLoader for custom coco dataset."""
90+ # COCO caption dataset
91+ coco = CocoDataset (root = root ,
92+ json = json ,
93+ vocab = vocab ,
94+ transform = transform )
95+
96+ # Data loader for COCO dataset
97+ # This will return (images, captions, lengths) for every iteration.
98+ # images: tensor of shape (batch_size, 3, 224, 224).
99+ # captions: tensor of shape (batch_size, padded_length).
100+ # lengths: list indicating valid length for each caption. length is (batch_size).
101+ data_loader = torch .utils .data .DataLoader (dataset = coco ,
102+ batch_size = batch_size ,
103+ shuffle = shuffle ,
104+ num_workers = num_workers ,
105+ collate_fn = collate_fn )
106+ return data_loader
0 commit comments