
Commit 4fc2b1f

captioning modules are edited
1 parent 247de2d commit 4fc2b1f

File tree

5 files changed: +128 -109 lines changed


tutorials/09 - Image Captioning/data.py

Lines changed: 14 additions & 12 deletions
@@ -13,12 +13,13 @@
 class CocoDataset(data.Dataset):
     """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
     def __init__(self, root, json, vocab, transform=None):
-        """
+        """Set the path for images, captions and vocabulary wrapper.
+
         Args:
             root: image directory.
             json: coco annotation file path.
             vocab: vocabulary wrapper.
-            transform: transformer for image.
+            transform: image transformer
         """
         self.root = root
         self.coco = COCO(json)
@@ -27,7 +28,7 @@ def __init__(self, root, json, vocab, transform=None):
         self.transform = transform
 
     def __getitem__(self, index):
-        """This function should return one data pair(image and caption)."""
+        """Returns one data pair (image and caption)."""
         coco = self.coco
         vocab = self.vocab
         ann_id = self.ids[index]
@@ -53,12 +54,13 @@ def __len__(self):
 
 
 def collate_fn(data):
-    """Build mini-batch tensors from a list of (image, caption) tuples.
+    """Creates mini-batch tensors from the list of tuples (image, caption).
+
     Args:
-        data: list of (image, caption) tuple.
+        data: list of tuple (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
-
+
     Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
@@ -68,10 +70,10 @@ def collate_fn(data):
     data.sort(key=lambda x: len(x[1]), reverse=True)
     images, captions = zip(*data)
 
-    # Merge images (convert tuple of 3D tensor to 4D tensor)
+    # Merge images (from tuple of 3D tensor to 4D tensor)
     images = torch.stack(images, 0)
 
-    # Merget captions (convert tuple of 1D tensor to 2D tensor)
+    # Merge captions (from tuple of 1D tensor to 2D tensor)
     lengths = [len(cap) for cap in captions]
     targets = torch.zeros(len(captions), max(lengths)).long()
     for i, cap in enumerate(captions):
@@ -80,18 +82,18 @@ def collate_fn(data):
     return images, targets, lengths
 
 
-def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2):
+def get_data_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
     """Returns torch.utils.data.DataLoader for custom coco dataset."""
-    # COCO custom dataset
+    # COCO dataset
     coco = CocoDataset(root=root,
                        json=json,
                        vocab = vocab,
                        transform=transform)
 
-    # Data loader
+    # Data loader for COCO dataset
     data_loader = torch.utils.data.DataLoader(dataset=coco,
                                               batch_size=batch_size,
-                                              shuffle=True,
+                                              shuffle=shuffle,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn)
     return data_loader
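
For reference (not part of the commit), a minimal sketch of the padded batches that collate_fn produces, using two toy captions of unequal length; the token ids below are made up:

import torch
from data import collate_fn

image = torch.randn(3, 256, 256)                      # dummy image tensor
batch = [(image, torch.LongTensor([1, 4, 7, 2])),     # caption of length 4
         (image, torch.LongTensor([1, 5, 2]))]        # caption of length 3
images, targets, lengths = collate_fn(batch)
print(images.size())   # torch.Size([2, 3, 256, 256])
print(targets)         # shorter caption zero-padded to length 4
print(lengths)         # [4, 3] (sorted longest first)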

tutorials/09 - Image Captioning/model.py

Lines changed: 21 additions & 19 deletions
@@ -7,58 +7,60 @@
 
 class EncoderCNN(nn.Module):
     def __init__(self, embed_size):
-        """Load pretrained ResNet-152 and replace top fc layer."""
+        """Loads the pretrained ResNet-152 and replace top fc layer."""
         super(EncoderCNN, self).__init__()
         self.resnet = models.resnet152(pretrained=True)
-        # For efficient memory usage.
         for param in self.resnet.parameters():
             param.requires_grad = False
         self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
         self.init_weights()
-
+
     def init_weights(self):
-        self.resnet.fc.weight.data.uniform_(-0.1, 0.1)
+        """Initialize weights."""
+        self.resnet.fc.weight.data.normal_(0.0, 0.02)
         self.resnet.fc.bias.data.fill_(0)
 
     def forward(self, images):
-        """Extract image feature vectors."""
+        """Extracts the image feature vectors."""
         features = self.resnet(images)
+        features = self.bn(features)
         return features
 
 
 class DecoderRNN(nn.Module):
     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
-        """Set hyper-parameters and build layers."""
+        """Set the hyper-parameters and build the layers."""
         super(DecoderRNN, self).__init__()
-        self.embed_size = embed_size
-        self.hidden_size = hidden_size
-        self.vocab_size = vocab_size
         self.embed = nn.Embedding(vocab_size, embed_size)
-        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
         self.linear = nn.Linear(hidden_size, vocab_size)
+        self.init_weights()
 
     def init_weights(self):
+        """Initialize weights."""
         self.embed.weight.data.uniform_(-0.1, 0.1)
-        self.linear.weigth.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
         self.linear.bias.data.fill_(0)
 
     def forward(self, features, captions, lengths):
-        """Decode image feature vectors and generate caption."""
+        """Decodes image feature vectors and generates captions."""
         embeddings = self.embed(captions)
         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
         packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
         hiddens, _ = self.lstm(packed)
         outputs = self.linear(hiddens[0])
         return outputs
 
-    def sample(self, feature, state):
-        """Sample a caption for given a image feature."""
+    def sample(self, features, states):
+        """Samples captions for given image features."""
         sampled_ids = []
-        input = feature.unsqueeze(1)
+        inputs = features.unsqueeze(1)
         for i in range(20):
-            hidden, state = self.lstm(input, state)          # (1, 1, hidden_size)
-            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, vocab_size)
-            predicted = output.max(1)[1]
+            hiddens, states = self.lstm(inputs, states)      # (batch_size, 1, hidden_size)
+            outputs = self.linear(hiddens.squeeze(1))        # (batch_size, vocab_size)
+            predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
-            input = self.embed(predicted)
+            inputs = self.embed(predicted)
+        sampled_ids = torch.cat(sampled_ids, 1)              # (batch_size, 20)
         return sampled_ids
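
As a quick sanity check (not in the commit), here is roughly how the two modules now fit together on a dummy batch; the sizes 256/512/5000 are placeholders, and the run assumes the pretrained ResNet-152 weights can be loaded:

import torch
from torch.autograd import Variable
from model import EncoderCNN, DecoderRNN

encoder = EncoderCNN(embed_size=256)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=5000, num_layers=1)

images = Variable(torch.randn(4, 3, 224, 224))    # batch of 4 random crops
captions = Variable(torch.ones(4, 10).long())     # padded caption ids
lengths = [10, 9, 8, 7]                           # sorted, as collate_fn guarantees

features = encoder(images)                        # (4, 256), batch-normalized embeddings
outputs = decoder(features, captions, lengths)    # (sum(lengths), 5000) packed word scores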

tutorials/09 - Image Captioning/resize.py

Lines changed: 9 additions & 9 deletions
@@ -1,34 +1,34 @@
 from PIL import Image
+from configuration import Config
 import os
 
 
 def resize_image(image, size):
-    """Resizes an image to the given size."""
+    """Resizes the image to the given size."""
     return image.resize(size, Image.ANTIALIAS)
 
 def resize_images(image_dir, output_dir, size):
-    """Resizes the images in the image_dir and save into the output_dir."""
+    """Resizes the images in 'image_dir' and save them in 'output_dir'."""
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-
     images = os.listdir(image_dir)
     num_images = len(images)
     for i, image in enumerate(images):
         with open(os.path.join(image_dir, image), 'r+b') as f:
             with Image.open(f) as img:
                 img = resize_image(img, size)
-                img.save(
-                    os.path.join(output_dir, image), img.format)
+                img.save(os.path.join(output_dir, image), img.format)
         if i % 100 == 0:
-            print ('[%d/%d] Resized the images and saved into %s.'
+            print ('[%d/%d] Resized the images and saved them in %s.'
                    %(i, num_images, output_dir))
 
 def main():
+    config = Config()
     splits = ['train', 'val']
     for split in splits:
-        image_dir = './data/%s2014/' %split
-        output_dir = './data/%s2014resized' %split
-        resize_images(image_dir, output_dir, (256, 256))
+        image_dir = os.path.join(config.image_path, '%s2014/' %split)
+        output_dir = os.path.join(config.image_path, '%s2014resized' %split)
+        resize_images(image_dir, output_dir, (config.image_size, config.image_size))
 
 
 if __name__ == '__main__':
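
Both resize.py and the training script now read their settings from a Config object imported from configuration.py, which is not included in this diff. A guessed minimal version, covering only the attributes the edited files reference (values are placeholders; several mirror the hard-coded hyper-parameters removed from the training script below):

class Config(object):
    """Placeholder for the hyper-parameters and paths the captioning scripts expect."""
    # paths (guessed defaults)
    image_path = './data/'
    caption_path = './data/annotations/'
    vocab_path = './data/'
    model_path = './models/'
    # image preprocessing
    image_size = 256
    crop_size = 224
    # model
    embed_size = 256
    hidden_size = 512
    num_layers = 1
    # training
    num_epochs = 5
    batch_size = 32
    num_threads = 2
    learning_rate = 0.001
    log_step = 100
    save_step = 1000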
Lines changed: 70 additions & 58 deletions
@@ -1,72 +1,84 @@
-from data import get_loader
+from data import get_data_loader
 from vocab import Vocabulary
+from configuration import Config
 from model import EncoderCNN, DecoderRNN
 from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
-import torch.nn as nn
-import numpy as np
+import torch.nn as nn
 import torchvision.transforms as T
+import numpy as np
 import pickle
+import os
 
-# Hyper Parameters
-num_epochs = 1
-batch_size = 32
-embed_size = 256
-hidden_size = 512
-crop_size = 224
-num_layers = 1
-learning_rate = 0.001
-train_image_path = './data/train2014resized/'
-train_json_path = './data/annotations/captions_train2014.json'
 
-# Image Preprocessing
-transform = T.Compose([
-    T.RandomCrop(crop_size),
-    T.RandomHorizontalFlip(),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+def main():
+    # Configuration for hyper-parameters
+    config = Config()
+
+    # Image preprocessing
+    transform = T.Compose([
+        T.Scale(config.image_size),    # no resize
+        T.RandomCrop(config.crop_size),
+        T.RandomHorizontalFlip(),
+        T.ToTensor(),
+        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
-# Load Vocabulary Wrapper
-with open('./data/vocab.pkl', 'rb') as f:
+    # Load vocabulary wrapper
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)
-
-# Build Dataset Loader
-train_loader = get_loader(train_image_path, train_json_path, vocab, transform,
-                          batch_size=batch_size, shuffle=True, num_workers=2)
-total_step = len(train_loader)
 
-# Build Models
-encoder = EncoderCNN(embed_size)
-decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers)
-encoder.cuda()
-decoder.cuda()
-
-# Loss and Optimizer
-criterion = nn.CrossEntropyLoss()
-params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
-optimizer = torch.optim.Adam(params, lr=learning_rate)
+    # Build data loader
+    image_path = os.path.join(config.image_path, 'train2014')
+    json_path = os.path.join(config.caption_path, 'captions_train2014.json')
+    train_loader = get_data_loader(image_path, json_path, vocab,
+                                   transform, config.batch_size,
+                                   shuffle=True, num_workers=config.num_threads)
+    total_step = len(train_loader)
+
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
+    encoder.cuda()
+    decoder.cuda()
+
+    # Loss and Optimizer
+    criterion = nn.CrossEntropyLoss()
+    params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
+    optimizer = torch.optim.Adam(params, lr=config.learning_rate)
+
+    # Train the Models
+    for epoch in range(config.num_epochs):
+        for i, (images, captions, lengths) in enumerate(train_loader):
+
+            # Set mini-batch dataset
+            images = Variable(images).cuda()
+            captions = Variable(captions).cuda()
+            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
 
-# Train the Decoder
-for epoch in range(num_epochs):
-    for i, (images, captions, lengths) in enumerate(train_loader):
-        # Set mini-batch dataset
-        images = Variable(images).cuda()
-        captions = Variable(captions).cuda()
-        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
-
-        # Forward, Backward and Optimize
-        decoder.zero_grad()
-        features = encoder(images)
-        outputs = decoder(features, captions, lengths)
-        loss = criterion(outputs, targets)
-        loss.backward()
-        optimizer.step()
-
-        if i % 100 == 0:
-            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
-                  %(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))
+            # Forward, Backward and Optimize
+            decoder.zero_grad()
+            encoder.zero_grad()
+            features = encoder(images)
+            outputs = decoder(features, captions, lengths)
+            loss = criterion(outputs, targets)
+            loss.backward()
+            optimizer.step()
 
-# Save the Model
-torch.save(decoder, 'decoder.pkl')
-torch.save(encoder, 'encoder.pkl')
+            # Print log info
+            if i % config.log_step == 0:
+                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
+                      %(epoch, config.num_epochs, i, total_step,
+                        loss.data[0], np.exp(loss.data[0])))
+
+            # Save the Model
+            if (i+1) % config.save_step == 0:
+                torch.save(decoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'decoder-%d-%d.pkl' %(epoch+1, i+1)))
+                torch.save(encoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'encoder-%d-%d.pkl' %(epoch+1, i+1)))
+if __name__ == '__main__':
+    main()
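
Because the commit switches from pickling whole modules to saving state_dicts, restoring a checkpoint now means rebuilding the models first. A rough sketch (not from the commit; the file names just follow the 'encoder-<epoch>-<step>.pkl' pattern used in the training loop above):

import os
import pickle
import torch
from configuration import Config
from vocab import Vocabulary
from model import EncoderCNN, DecoderRNN

config = Config()
with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
    vocab = pickle.load(f)

# Rebuild the models with the same sizes used for training, then load weights.
encoder = EncoderCNN(config.embed_size)
decoder = DecoderRNN(config.embed_size, config.hidden_size, len(vocab), config.num_layers)
encoder.load_state_dict(torch.load(os.path.join(config.model_path, 'encoder-1-1000.pkl')))
decoder.load_state_dict(torch.load(os.path.join(config.model_path, 'decoder-1-1000.pkl')))
encoder.eval()
decoder.eval()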

0 commit comments
