1+ import  torch 
2+ import  torchvision .transforms  as  transforms 
3+ import  torch .utils .data  as  data 
4+ import  os 
5+ import  pickle 
6+ import  numpy  as  np 
7+ import  nltk 
8+ from  PIL  import  Image 
9+ from  vocab  import  Vocabulary 
10+ from  pycocotools .coco  import  COCO 
11+ 
12+ 
class CocoDataset(data.Dataset):
    """COCO caption dataset compatible with torch.utils.data.DataLoader.

    One item corresponds to one annotation of the loaded COCO index (not to
    one image): ``ids`` holds the annotation keys, and each item pairs the
    annotation's caption with the image that annotation refers to.
    """

    def __init__(self, root, json, vocab, transform=None):
        """Load the annotation index and store the loading options.

        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper; it is called as ``vocab(token)`` to
                map a token string to its id.
            transform: optional callable applied to the loaded RGB image.
        """
        self.root = root
        self.coco = COCO(json)
        # Snapshot the annotation keys so items can be addressed by position.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return one data pair (image, caption) for the given position.

        Args:
            index: position into ``self.ids``.

        Returns:
            A tuple ``(image, target)``. ``image`` is the RGB image, passed
            through ``self.transform`` when one was given. ``target`` is a
            1-D tensor holding the ids of ``<start>``, the caption tokens
            and ``<end>``, in that order.
        """
        annotation = self.coco.anns[self.ids[index]]
        raw_caption = annotation['caption']
        image_info = self.coco.loadImgs(annotation['image_id'])[0]
        image_path = os.path.join(self.root, image_info['file_name'])

        # Close the underlying file deterministically; convert() returns a
        # separate, fully loaded image that stays valid after the close.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids, wrapped in start/end markers.
        tokens = nltk.tokenize.word_tokenize(str(raw_caption).lower())
        word_ids = [self.vocab('<start>')]
        word_ids.extend(self.vocab(token) for token in tokens)
        word_ids.append(self.vocab('<end>'))
        # torch.Tensor(...) is kept as in the original so the dtype of the
        # returned target does not change for existing callers.
        target = torch.Tensor(word_ids)
        return image, target

    def __len__(self):
        """Return the number of annotations (items) in the dataset."""
        return len(self.ids)
53+ 
54+     
def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption) tuples.

    The samples are ordered by caption length, longest first, and the
    captions are right-padded with zeros to the longest length in the batch.
    The input list itself is left untouched.

    Args:
        data: list of (image, caption) tuples.
            - image: torch tensor; all images in the batch must share one
              shape (e.g. (3, H, W)) so they can be stacked.
            - caption: 1-D torch tensor of word ids; variable length.

    Returns:
        images: torch tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption, in the same
            (descending) order as the rows of ``targets``.
    """
    # Order by caption length without reordering the caller's list in place.
    # NOTE(review): descending order is presumably required by a downstream
    # pack_padded_sequence call -- confirm against the training code.
    ordered = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*ordered)

    # Merge images (convert tuple of 3D tensors to one 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (convert tuple of 1D tensors to one zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = cap[:length]
    return images, targets, lengths
81+ 
82+ 
def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2):
    """Return a torch.utils.data.DataLoader for the custom coco dataset.

    Args:
        root: image directory.
        json: coco annotation file path.
        vocab: vocabulary wrapper passed through to the dataset.
        transform: image transform passed through to the dataset.
        batch_size: number of (image, caption) pairs per batch.
        shuffle: whether the loader reshuffles the data every epoch.
        num_workers: number of worker processes used for loading.

    Returns:
        A DataLoader whose batches are assembled by ``collate_fn``, i.e.
        each iteration yields ``(images, targets, lengths)``.
    """
    # COCO custom dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)

    # Data loader. Forward the caller's ``shuffle`` choice instead of always
    # shuffling, so shuffle=False gives a deterministic order.
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
0 commit comments