Skip to content

Commit 4769688

Browse files
committed
modified the code
1 parent 0e59313 commit 4769688

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import nltk
2+
import pickle
3+
import argparse
4+
from collections import Counter
5+
from pycocotools.coco import COCO
6+
7+
8+
class Vocabulary(object):
    """Simple vocabulary wrapper.

    Keeps a two-way mapping between words and consecutive integer indices.
    Calling the instance with a word returns its index; words that were
    never added resolve to the index of the '<unk>' token.
    """

    def __init__(self):
        # word -> index
        self.word2idx = {}
        # index -> word
        self.idx2word = {}
        # Next index to hand out; equals the number of words added so far.
        self.idx = 0

    def add_word(self, word):
        """Register `word` under the next free index.

        Adding a word that is already present is a no-op, so indices stay
        stable and contiguous.
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        """Return the index of `word`, or the index of '<unk>' if unknown.

        Raises:
            KeyError: if `word` is unknown and '<unk>' was never added.
        """
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.word2idx)
28+
29+
def build_vocab(json, threshold):
    """Build a vocabulary from the captions of a COCO annotation file.

    Every caption is lower-cased and tokenized with NLTK's
    `word_tokenize`; tokens are counted across all captions and only those
    seen at least `threshold` times are kept.

    Args:
        json: Path of the COCO caption annotation file (handed to `COCO`).
        threshold: Minimum occurrence count for a token to be kept.

    Returns:
        A `Vocabulary` whose first four entries are the special tokens
        '<pad>', '<start>', '<end>' and '<unk>' (in that order), followed
        by the kept tokens.
    """
    coco = COCO(json)
    counter = Counter()
    num_captions = len(coco.anns)
    # Iterate the annotation records directly instead of looking each one
    # up again by key (this also avoids shadowing the builtin `id`).
    for i, annotation in enumerate(coco.anns.values()):
        caption = str(annotation['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if i % 1000 == 0:
            print("[%d/%d] Tokenized the captions." % (i, num_captions))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create the vocab wrapper and add the special tokens first so that
    # they always occupy indices 0-3.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the remaining words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab
56+
57+
def main(args):
    """Build the vocabulary and pickle it to `args.vocab_path`.

    Reads `caption_path`, `threshold` and `vocab_path` from `args`, then
    prints the vocabulary size and the destination path.
    """
    vocabulary = build_vocab(
        json=args.caption_path,
        threshold=args.threshold,
    )
    destination = args.vocab_path

    with open(destination, 'wb') as handle:
        pickle.dump(vocabulary, handle, pickle.HIGHEST_PROTOCOL)

    print("Total vocabulary size: %d" % len(vocabulary))
    print("Saved the vocabulary wrapper to '%s'" % destination)
65+
66+
67+
if __name__ == '__main__':
    # Command-line options as (flag, type, default, help text).
    option_table = (
        ('--caption_path', str,
         './data/annotations/captions_train2014.json',
         'path for train annotation file'),
        ('--vocab_path', str, './data/vocab.pkl',
         'path for saving vocabulary wrapper'),
        ('--threshold', int, 4,
         'minimum word count threshold'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)

    args = parser.parse_args()
    main(args)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import argparse
2+
import os
3+
from PIL import Image
4+
5+
6+
def resize_image(image, size):
    """Resize an image to the given size.

    Args:
        image: The PIL image to resize.
        size: Target size, passed straight to `image.resize`.

    Returns:
        The resized copy produced by `image.resize`.
    """
    # `Image.ANTIALIAS` was an alias of `Image.LANCZOS` and was removed in
    # Pillow 10; `Image.LANCZOS` selects the same filter on old and new
    # Pillow releases alike.
    return image.resize(size, Image.LANCZOS)
9+
10+
def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'.

    Each entry of `image_dir` is opened as an image, resized with
    `resize_image` and written to `output_dir` under the same file name.
    `output_dir` is created if it does not exist. A progress line is
    printed for every 100th entry.

    Args:
        image_dir: Directory containing the source images.
        output_dir: Directory that receives the resized images.
        size: Target size forwarded to `resize_image`.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        # The source is only read, so plain 'rb' is enough; 'r+b' needlessly
        # demanded write permission on the input files.
        with open(os.path.join(image_dir, image), 'rb') as f:
            with Image.open(f) as img:
                # Remember the source format before resizing: the resized
                # copy is a new image whose `format` is None, so reading it
                # afterwards never preserved the original format.
                source_format = img.format
                resized = resize_image(img, size)
                resized.save(os.path.join(output_dir, image), source_format)
        if i % 100 == 0:
            print("[%d/%d] Resized the images and saved into '%s'."
                  % (i, num_images, output_dir))
25+
26+
def main(args):
    """Resize every image in `args.image_dir` into `args.output_dir`.

    Images are resized to a square of `args.image_size` pixels per side.
    """
    # The previous version looped over ['train', 'val'] without ever using
    # the split name, so the very same directory was resized twice. One
    # pass produces the same output files.
    image_size = [args.image_size, args.image_size]
    resize_images(args.image_dir, args.output_dir, image_size)
33+
34+
35+
if __name__ == '__main__':
    # Command-line options as (flag, type, default, help text).
    option_table = (
        ('--image_dir', str, './data/train2014/',
         'directory for train images'),
        ('--output_dir', str, './data/resized2014/',
         'directory for saving resized images'),
        ('--image_size', int, 256,
         'size for image after processing'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)

    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)