1+ import torch
2+ import torchvision .transforms as transforms
3+ import torch .utils .data as data
4+ import os
5+ import pickle
6+ import numpy as np
7+ import nltk
8+ from PIL import Image
9+ from vocab import Vocabulary
10+ from pycocotools .coco import COCO
11+
12+
class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""

    def __init__(self, root, json, vocab, transform=None):
        """
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: transformer for image.
        """
        self.root = root
        self.coco = COCO(json)
        # One sample per annotation (caption), not per image.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return one (image, caption) data pair for the given index."""
        annotation = self.coco.anns[self.ids[index]]
        raw_caption = annotation['caption']
        file_name = self.coco.loadImgs(annotation['image_id'])[0]['file_name']

        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Tokenize the caption and map every token to its vocabulary id,
        # bracketed by the <start>/<end> sentinel tokens.
        words = nltk.tokenize.word_tokenize(str(raw_caption).lower())
        word_ids = [self.vocab('<start>')]
        word_ids.extend(self.vocab(word) for word in words)
        word_ids.append(self.vocab('<end>'))
        target = torch.Tensor(word_ids)
        return image, target

    def __len__(self):
        """Number of samples (one per caption annotation)."""
        return len(self.ids)
53+
54+
def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption) tuples.
    Args:
        data: list of (image, caption) tuple.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Order the batch longest-caption-first (in place), as expected by
    # pack_padded_sequence-style consumers downstream.
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*data)

    # Tuple of 3D image tensors -> single 4D batch tensor.
    images = torch.stack(images, 0)

    # Zero-pad every caption to the length of the longest one.
    lengths = [len(caption) for caption in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, caption in enumerate(captions):
        valid = lengths[row]
        targets[row, :valid] = caption[:valid]
    return images, targets, lengths
81+
82+
def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2):
    """Returns torch.utils.data.DataLoader for custom coco dataset.

    Args:
        root: image directory.
        json: coco annotation file path.
        vocab: vocabulary wrapper.
        transform: transformer applied to each image.
        batch_size: number of (image, caption) pairs per mini-batch.
        shuffle: whether to reshuffle the dataset every epoch.
        num_workers: number of worker subprocesses for data loading.

    Returns:
        A DataLoader yielding (images, targets, lengths) batches as
        produced by collate_fn.
    """
    # COCO custom dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)

    # BUG FIX: shuffle was previously hard-coded to True, silently
    # ignoring the caller's `shuffle` argument (e.g. for a validation
    # loader that must not shuffle). Pass the parameter through.
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader