
Commit 6f5fda1

image captioning completed
1 parent ba7d546 commit 6f5fda1

7 files changed: +297 -68 lines changed

tutorials/09 - Image Captioning/configuration.py
Lines changed: 22 additions & 1 deletion

@@ -1,3 +1,6 @@
+import torchvision.transforms as T
+
+
 class Config(object):
     """Wrapper class for hyper-parameters."""
     def __init__(self):
@@ -8,6 +11,21 @@ def __init__(self):
         self.word_count_threshold = 4
         self.num_threads = 2
 
+        # Image preprocessing in training phase
+        self.train_transform = T.Compose([
+            T.Scale(self.image_size),
+            T.RandomCrop(self.crop_size),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+        # Image preprocessing in test phase
+        self.test_transform = T.Compose([
+            T.Scale(self.crop_size),
+            T.CenterCrop(self.crop_size),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
         # Training
         self.num_epochs = 5
         self.batch_size = 64
@@ -23,4 +41,7 @@ def __init__(self):
         # Path
         self.image_path = './data/'
         self.caption_path = './data/annotations/'
-        self.vocab_path = './data/'
+        self.vocab_path = './data/'
+        self.model_path = './model/'
+        self.trained_encoder = 'encoder-4-6000.pkl'
+        self.trained_decoder = 'decoder-4-6000.pkl'
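
Moving both pipelines into Config gives train.py and sample.py a single source of truth for preprocessing. A minimal sketch of how a caller consumes these transforms (the image path is hypothetical; note also that T.Scale was later renamed T.Resize in torchvision, so newer environments need that substitution):

from PIL import Image
from configuration import Config

config = Config()
img = Image.open('./data/example.jpg')        # hypothetical path

# Training: random crop + horizontal flip for augmentation
train_tensor = config.train_transform(img)    # (3, crop_size, crop_size)

# Test: deterministic scale + center crop, same normalization
test_tensor = config.test_transform(img)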

tutorials/09 - Image Captioning/data.py
Lines changed: 2 additions & 0 deletions

@@ -2,11 +2,13 @@
 import torchvision.transforms as transforms
 import torch.utils.data as data
 import os
+import sys
 import pickle
 import numpy as np
 import nltk
 from PIL import Image
 from vocab import Vocabulary
+sys.path.append('../../../coco/PythonAPI')
 from pycocotools.coco import COCO
 
 
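The new sys.path line makes the locally built COCO Python API importable without a pip install; the relative path assumes a cocoapi checkout three levels above the tutorial directory. A hedged sketch of what the loader can do once the import succeeds (the annotation filename is the standard COCO 2014 captions file, assumed here):

from pycocotools.coco import COCO

coco = COCO('./data/annotations/captions_train2014.json')  # assumed path
ann_id = list(coco.anns.keys())[0]
caption = coco.anns[ann_id]['caption']     # raw caption string
img_id = coco.anns[ann_id]['image_id']     # id of the captioned image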

tutorials/09 - Image Captioning/evaluate_model.ipynb
Lines changed: 177 additions & 0 deletions
(Large diffs are not rendered by default.)

tutorials/09 - Image Captioning/model.py
Lines changed: 4 additions & 4 deletions

@@ -53,14 +53,14 @@ def forward(self, features, captions, lengths):
         return outputs
 
     def sample(self, features, states):
-        """Samples captions for given image features."""
+        """Samples captions for given image features (Greedy search)."""
         sampled_ids = []
         inputs = features.unsqueeze(1)
         for i in range(20):
             hiddens, states = self.lstm(inputs, states)    # (batch_size, 1, hidden_size)
-            outputs = self.linear(hiddens.unsqueeze())     # (batch_size, vocab_size)
+            outputs = self.linear(hiddens.squeeze(1))      # (batch_size, vocab_size)
             predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
             inputs = self.embed(predicted)
-        sampled_ids = torch.cat(sampled_ids, 1)            # (batch_size, 20)
-        return sampled_ids
+        sampled_ids = torch.cat(sampled_ids, 1)            # (batch_size, 20)
+        return sampled_ids.squeeze()
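
The squeeze(1) fix is the substantive change here: an LSTM run on a length-1 sequence returns hiddens of shape (batch_size, 1, hidden_size), and the linear layer wants 2-D input, while the old hiddens.unsqueeze() is not even a valid call (unsqueeze requires a dim argument). A standalone shape check of one greedy step, written against current PyTorch for brevity (the tutorial wraps tensors in Variable; sizes are illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(256, 512, batch_first=True)
linear = nn.Linear(512, 10000)          # vocab_size = 10000, illustrative

inputs = torch.zeros(1, 1, 256)         # (batch, seq_len=1, embed_size)
hiddens, states = lstm(inputs)          # hiddens: (1, 1, 512)
outputs = linear(hiddens.squeeze(1))    # (1, 10000) logits over the vocab
predicted = outputs.max(1)[1]           # greedy: argmax token id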
tutorials/09 - Image Captioning/requirements.txt
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+matplotlib==2.0.0
+nltk==3.2.2
+numpy==1.12.0
+Pillow==4.0.0
+argparse
tutorials/09 - Image Captioning/sample.py
Lines changed: 69 additions & 50 deletions

@@ -1,58 +1,77 @@
-import os
-import numpy as np
+from vocab import Vocabulary
+from model import EncoderCNN, DecoderRNN
+from configuration import Config
+from PIL import Image
+from torch.autograd import Variable
 import torch
-import torchvision.transforms as T
-import pickle
+import torchvision.transforms as T
 import matplotlib.pyplot as plt
-from PIL import Image
-from model import EncoderCNN, DecoderRNN
-from vocab import Vocabulary
-from torch.autograd import Variable
-
-# Image processing
-transform = T.Compose([
-    T.CenterCrop(224),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-
-# Hyper Parameters
-embed_size = 128
-hidden_size = 512
-num_layers = 1
-
-# Load vocabulary
-with open('./data/vocab.pkl', 'rb') as f:
-    vocab = pickle.load(f)
-
-# Load an image array
-images = os.listdir('./data/train2014resized/')
-image_path = './data/train2014resized/' + images[12]
-img = Image.open(image_path)
-image = transform(img).unsqueeze(0)
+import numpy as np
+import argparse
+import pickle
+import os
 
-# Load the trained models
-encoder = torch.load('./encoder.pkl')
-decoder = torch.load('./decoder.pkl')
 
-# Encode the image
-feature = encoder(Variable(image).cuda())
+def main(params):
+    # Configuration for hyper-parameters
+    config = Config()
+
+    # Image Preprocessing
+    transform = config.test_transform
 
-# Set initial states
-state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
-         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+    # Load vocabulary
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
+        vocab = pickle.load(f)
 
-# Decode the feature to caption
-ids = decoder.sample(feature, state)
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
+
 
-words = []
-for id in ids:
-    word = vocab.idx2word[id.data[0, 0]]
-    words.append(word)
-    if word == '<end>':
-        break
-caption = ' '.join(words)
+    # Load the trained model parameters
+    encoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_encoder)))
+    decoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_decoder)))
 
-# Display the image and generated caption
-plt.imshow(img)
-plt.show()
-print (caption)
+    # Prepare Image
+    image = Image.open(params['image'])
+    image_tensor = Variable(transform(image).unsqueeze(0))
+
+    # Set initial states
+    state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
+             Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
+
+    # If use gpu
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
+        state = [s.cuda() for s in state]
+        image_tensor = image_tensor.cuda()
+
+    # Generate caption from image
+    feature = encoder(image_tensor)
+    sampled_ids = decoder.sample(feature, state)
+    sampled_ids = sampled_ids.cpu().data.numpy()
+
+    # Decode word_ids to words
+    sampled_caption = []
+    for word_id in sampled_ids:
+        word = vocab.idx2word[word_id]
+        sampled_caption.append(word)
+        if word == '<end>':
+            break
+    sentence = ' '.join(sampled_caption)
+
+    # Print out image and generated caption.
+    print (sentence)
+    plt.imshow(np.asarray(image))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--image', type=str, required=True,
+                        help='image for generating caption')
+    args = parser.parse_args()
+    params = vars(args)
+    main(params)
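
With this refactor, caption generation becomes a command-line script instead of a hard-coded one-off run. A hypothetical invocation (the image path is an assumption; vocab.pkl and the encoder-4-6000.pkl / decoder-4-6000.pkl weights must already exist from a training run):

python sample.py --image ./data/example.png

Note that the script only calls plt.imshow without plt.show, so from a plain terminal it prints the caption but does not pop up the image window.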

tutorials/09 - Image Captioning/train.py
Lines changed: 18 additions & 13 deletions

@@ -6,7 +6,6 @@
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
 import torch.nn as nn
-import torchvision.transforms as T
 import numpy as np
 import pickle
 import os
@@ -16,14 +15,13 @@ def main():
     # Configuration for hyper-parameters
     config = Config()
 
+    # Create model directory
+    if not os.path.exists(config.model_path):
+        os.makedirs(config.model_path)
+
     # Image preprocessing
-    transform = T.Compose([
-        T.Scale(config.image_size),    # no resize
-        T.RandomCrop(config.crop_size),
-        T.RandomHorizontalFlip(),
-        T.ToTensor(),
-        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
+    transform = config.train_transform
+
     # Load vocabulary wrapper
     with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)
@@ -40,22 +38,28 @@ def main():
     encoder = EncoderCNN(config.embed_size)
     decoder = DecoderRNN(config.embed_size, config.hidden_size,
                          len(vocab), config.num_layers)
-    encoder.cuda()
-    decoder.cuda()
+
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
 
     # Loss and Optimizer
     criterion = nn.CrossEntropyLoss()
     params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
     optimizer = torch.optim.Adam(params, lr=config.learning_rate)
 
     # Train the Models
     for epoch in range(config.num_epochs):
         for i, (images, captions, lengths) in enumerate(train_loader):
 
             # Set mini-batch dataset
-            images = Variable(images).cuda()
-            captions = Variable(captions).cuda()
+            images = Variable(images)
+            captions = Variable(captions)
+
+            if torch.cuda.is_available():
+                images = images.cuda()
+                captions = captions.cuda()
+
             targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
 
             # Forward, Backward and Optimize
             decoder.zero_grad()
@@ -80,5 +84,6 @@ def main():
             torch.save(encoder.state_dict(),
                        os.path.join(config.model_path,
                                     'encoder-%d-%d.pkl' %(epoch+1, i+1)))
+
 if __name__ == '__main__':
     main()
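
The training loop leans on pack_padded_sequence to build loss targets that skip padding: given captions sorted by decreasing length, it flattens only the real tokens into one tensor. A toy sketch of the mechanics, written against current PyTorch (the tutorial's version additionally wraps tensors in Variable):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

captions = torch.tensor([[1, 2, 3, 4],      # length 4
                         [1, 2, 3, 0]])     # length 3, trailing pad
lengths = [4, 3]

# batch_first=True reads (batch, seq); [0] grabs the packed .data tensor
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
print(targets)    # tensor([1, 1, 2, 2, 3, 3, 4]), the pad token is dropped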
