1+ import nltk
2+ import pickle
3+ import argparse
4+ from collections import Counter
5+ from pycocotools .coco import COCO
6+
7+
8+ class Vocabulary (object ):
9+ """Simple vocabulary wrapper."""
10+ def __init__ (self ):
11+ self .word2idx = {}
12+ self .idx2word = {}
13+ self .idx = 0
14+
15+ def add_word (self , word ):
16+ if not word in self .word2idx :
17+ self .word2idx [word ] = self .idx
18+ self .idx2word [self .idx ] = word
19+ self .idx += 1
20+
21+ def __call__ (self , word ):
22+ if not word in self .word2idx :
23+ return self .word2idx ['<unk>' ]
24+ return self .word2idx [word ]
25+
26+ def __len__ (self ):
27+ return len (self .word2idx )
28+
29+ def build_vocab (json , threshold ):
30+ """Build a simple vocabulary wrapper."""
31+ coco = COCO (json )
32+ counter = Counter ()
33+ ids = coco .anns .keys ()
34+ for i , id in enumerate (ids ):
35+ caption = str (coco .anns [id ]['caption' ])
36+ tokens = nltk .tokenize .word_tokenize (caption .lower ())
37+ counter .update (tokens )
38+
39+ if i % 1000 == 0 :
40+ print ("[%d/%d] Tokenized the captions." % (i , len (ids )))
41+
42+ # If the word frequency is less than 'threshold', then the word is discarded.
43+ words = [word for word , cnt in counter .items () if cnt >= threshold ]
44+
45+ # Creates a vocab wrapper and add some special tokens.
46+ vocab = Vocabulary ()
47+ vocab .add_word ('<pad>' )
48+ vocab .add_word ('<start>' )
49+ vocab .add_word ('<end>' )
50+ vocab .add_word ('<unk>' )
51+
52+ # Adds the words to the vocabulary.
53+ for i , word in enumerate (words ):
54+ vocab .add_word (word )
55+ return vocab
56+
57+ def main (args ):
58+ vocab = build_vocab (json = args .caption_path ,
59+ threshold = args .threshold )
60+ vocab_path = args .vocab_path
61+ with open (vocab_path , 'wb' ) as f :
62+ pickle .dump (vocab , f , pickle .HIGHEST_PROTOCOL )
63+ print ("Total vocabulary size: %d" % len (vocab ))
64+ print ("Saved the vocabulary wrapper to '%s'" % vocab_path )
65+
66+
67+ if __name__ == '__main__' :
68+ parser = argparse .ArgumentParser ()
69+ parser .add_argument ('--caption_path' , type = str ,
70+ default = './data/annotations/captions_train2014.json' ,
71+ help = 'path for train annotation file' )
72+ parser .add_argument ('--vocab_path' , type = str , default = './data/vocab.pkl' ,
73+ help = 'path for saving vocabulary wrapper' )
74+ parser .add_argument ('--threshold' , type = int , default = 4 ,
75+ help = 'minimum word count threshold' )
76+ args = parser .parse_args ()
77+ main (args )
0 commit comments