1+ import  torch 
2+ import  torchvision .transforms  as  transforms 
3+ import  torch .utils .data  as  data 
4+ import  os 
5+ import  pickle 
6+ import  numpy  as  np 
7+ import  nltk 
8+ from  PIL  import  Image 
9+ from  build_vocab  import  Vocabulary 
10+ from  pycocotools .coco  import  COCO 
11+ 
12+ 
13+ class  CocoDataset (data .Dataset ):
14+     """COCO Custom Dataset compatible with torch.utils.data.DataLoader.""" 
15+     def  __init__ (self , root , json , vocab , transform = None ):
16+         """Set the path for images, captions and vocabulary wrapper. 
17+          
18+         Args: 
19+             root: image directory. 
20+             json: coco annotation file path. 
21+             vocab: vocabulary wrapper. 
22+             transform: image transformer. 
23+         """ 
24+         self .root  =  root 
25+         self .coco  =  COCO (json )
26+         self .ids  =  list (self .coco .anns .keys ())
27+         self .vocab  =  vocab 
28+         self .transform  =  transform 
29+ 
30+     def  __getitem__ (self , index ):
31+         """Returns one data pair (image and caption).""" 
32+         coco  =  self .coco 
33+         vocab  =  self .vocab 
34+         ann_id  =  self .ids [index ]
35+         caption  =  coco .anns [ann_id ]['caption' ]
36+         img_id  =  coco .anns [ann_id ]['image_id' ]
37+         path  =  coco .loadImgs (img_id )[0 ]['file_name' ]
38+ 
39+         image  =  Image .open (os .path .join (self .root , path )).convert ('RGB' )
40+         if  self .transform  is  not None :
41+             image  =  self .transform (image )
42+ 
43+         # Convert caption (string) to word ids. 
44+         tokens  =  nltk .tokenize .word_tokenize (str (caption ).lower ())
45+         caption  =  []
46+         caption .append (vocab ('<start>' ))
47+         caption .extend ([vocab (token ) for  token  in  tokens ])
48+         caption .append (vocab ('<end>' ))
49+         target  =  torch .Tensor (caption )
50+         return  image , target 
51+ 
52+     def  __len__ (self ):
53+         return  len (self .ids )
54+ 
55+ 
56+ def  collate_fn (data ):
57+     """Creates mini-batch tensors from the list of tuples (image, caption). 
58+      
59+     We should build custom collate_fn rather than using default collate_fn,  
60+     because merging caption (including padding) is not supported in default. 
61+ 
62+     Args: 
63+         data: list of tuple (image, caption).  
64+             - image: torch tensor of shape (3, 256, 256). 
65+             - caption: torch tensor of shape (?); variable length. 
66+ 
67+     Returns: 
68+         images: torch tensor of shape (batch_size, 3, 256, 256). 
69+         targets: torch tensor of shape (batch_size, padded_length). 
70+         lengths: list; valid length for each padded caption. 
71+     """ 
72+     # Sort a data list by caption length (descending order). 
73+     data .sort (key = lambda  x : len (x [1 ]), reverse = True )
74+     images , captions  =  zip (* data )
75+ 
76+     # Merge images (from tuple of 3D tensor to 4D tensor). 
77+     images  =  torch .stack (images , 0 )
78+ 
79+     # Merge captions (from tuple of 1D tensor to 2D tensor). 
80+     lengths  =  [len (cap ) for  cap  in  captions ]
81+     targets  =  torch .zeros (len (captions ), max (lengths )).long ()
82+     for  i , cap  in  enumerate (captions ):
83+         end  =  lengths [i ]
84+         targets [i , :end ] =  cap [:end ]        
85+     return  images , targets , lengths 
86+ 
87+ 
88+ def  get_loader (root , json , vocab , transform , batch_size , shuffle , num_workers ):
89+     """Returns torch.utils.data.DataLoader for custom coco dataset.""" 
90+     # COCO caption dataset 
91+     coco  =  CocoDataset (root = root ,
92+                        json = json ,
93+                        vocab = vocab ,
94+                        transform = transform )
95+     
96+     # Data loader for COCO dataset 
97+     # This will return (images, captions, lengths) for every iteration. 
98+     # images: tensor of shape (batch_size, 3, 224, 224). 
99+     # captions: tensor of shape (batch_size, padded_length). 
100+     # lengths: list indicating valid length for each caption. length is (batch_size). 
101+     data_loader  =  torch .utils .data .DataLoader (dataset = coco , 
102+                                               batch_size = batch_size ,
103+                                               shuffle = shuffle ,
104+                                               num_workers = num_workers ,
105+                                               collate_fn = collate_fn )
106+     return  data_loader 
0 commit comments