Commit 76680e4

Merge pull request #1 from capimx/Project
Merge files for Bert
2 parents 4167bce + 1e78a90

File tree: 7 files changed, +124 -30 lines

Lines changed: 60 additions & 0 deletions
(new file: an RTF document whose text content is the Colab setup script below)

@@ -0,0 +1,60 @@
+import os
+os.chdir('pytorch-tutorial/tutorials/03-advanced/image_captioning/')
+
+import nltk
+import gensim
+nltk.download('punkt')
+
+!pip install bert-embedding
+!pip install https://github.com/dmlc/gluon-nlp/tarball/master
+!pip install mxnet-cu100
+
+os.chdir('/content')
+
+%%shell
+git clone https://github.com/pdollar/coco.git
+cd coco/PythonAPI/
+make
+python setup.py build
+python setup.py install
+cd ../../
+git clone https://github.com/capimx/pytorch-tutorial.git
+
+# Changing directories
+os.chdir('pytorch-tutorial/tutorials/03-advanced/image_captioning/')
+!sed -i 's/unzip /unzip -q /g' download.sh  # Make unzip quiet
+
+%%time
+%%shell
+pip install -r requirements.txt
+chmod +x download.sh
+./download.sh
+
+########### REPLACE ALL THE DOWNLOADED FILES WITH OURS ###########
+########### UPLOAD THE BERT EMBEDDINGS IN THE SAME FOLDER ###########
+
+%%time
+!python build_vocab.py
+
+%%time
+!python resize.py
+
+%%time
+!python train.py
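
After the installs above, it is worth confirming that the bert-embedding package actually loads before launching the long pipeline. A minimal sanity check, not part of the commit (BertEmbedding defaults to BERT-base uncased, which yields 768-dimensional token vectors):

from bert_embedding import BertEmbedding

bert = BertEmbedding()                   # defaults to BERT-base uncased (768-d)
result = bert(['a man rides a horse'])   # one (tokens, vectors) pair per input string
tokens, vectors = result[0]
print(len(tokens), vectors[0].shape)     # 5 (768,)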

tutorials/03-advanced/image_captioning/build_vocab.py

Lines changed: 12 additions & 6 deletions
@@ -6,26 +6,32 @@
 class Vocabulary(object):
-    """Simple vocabulary wrapper."""
+
     def __init__(self):
         self.word2idx = {}
         self.idx2word = {}
         self.idx = 0
-
+
     def add_word(self, word):
         if not word in self.word2idx:
             self.word2idx[word] = self.idx
             self.idx2word[self.idx] = word
             self.idx += 1
-
+
     def __call__(self, word):
         if not word in self.word2idx:
             return self.word2idx['<unk>']
         return self.word2idx[word]
-
+
     def __len__(self):
         return len(self.word2idx)
-
+
+    def __keys__(self):
+        iterable = []
+        for key in self.word2idx:
+            iterable.append(key)
+        return iterable
+
 def build_vocab(json, threshold):
     """Build a simple vocabulary wrapper."""
     coco = COCO(json)

@@ -73,4 +79,4 @@ def main(args):
     parser.add_argument('--threshold', type=int, default=4,
                         help='minimum word count threshold')
     args = parser.parse_args()
-    main(args)
+    main(args)
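
The new __keys__ helper returns the vocabulary's words in insertion order (it relies on dict ordering, guaranteed since Python 3.7); model.py uses it to skip the four special tokens that build_vocab registers first. A minimal sketch of the behaviour, assuming the usual <pad>, <start>, <end>, <unk> registration order:

vocab = Vocabulary()
for w in ['<pad>', '<start>', '<end>', '<unk>', 'horse']:
    vocab.add_word(w)

print(vocab.__keys__())            # ['<pad>', '<start>', '<end>', '<unk>', 'horse']
print(vocab.__keys__()[4:])        # ['horse'] -- the non-special words
print(vocab('horse'), len(vocab))  # 4 5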

tutorials/03-advanced/image_captioning/data_loader.py

Lines changed: 1 addition & 1 deletion
@@ -102,4 +102,4 @@ def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
                                               shuffle=shuffle,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn)
-    return data_loader
+    return data_loader

tutorials/03-advanced/image_captioning/model.py

Lines changed: 30 additions & 6 deletions
@@ -1,9 +1,11 @@
 import torch
+import gensim
+import numpy as np
 import torch.nn as nn
 import torchvision.models as models
+from bert_embedding import BertEmbedding
 from torch.nn.utils.rnn import pack_padded_sequence
 
-
 class EncoderCNN(nn.Module):
     def __init__(self, embed_size):
         """Load the pretrained ResNet-152 and replace top fc layer."""

@@ -24,12 +26,34 @@ def forward(self, images):
 
 
 class DecoderRNN(nn.Module):
-    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
+    def __init__(self, embed_size, hidden_size, vocab, num_layers, max_seq_length=20):
         """Set the hyper-parameters and build the layers."""
         super(DecoderRNN, self).__init__()
-        self.embed = nn.Embedding(vocab_size, embed_size)
+        Bert_file = "bert-base-uncased.30522.768d.vec"
+        print("M1")
+        Lookup = gensim.models.KeyedVectors.load_word2vec_format(Bert_file, binary=False)
+        bert_embedding = BertEmbedding()
+        Embed = np.zeros((len(vocab), embed_size))
+        print("M2")
+        Embed[vocab('<pad>'), :] = np.random.normal(0, 1, embed_size)
+        Embed[vocab('<start>'), :] = np.random.normal(0, 1, embed_size)
+        Embed[vocab('<end>'), :] = np.random.normal(0, 1, embed_size)
+        Embed[vocab('<unk>'), :] = np.random.normal(0, 1, embed_size)
+        print("M3")
+        for word in vocab.__keys__()[4:]:
+            try:
+                Embed[vocab(word), :] = Lookup[word]
+            except:
+                bert_word = word
+                token = bert_word.split('\n')
+                pred = bert_embedding(token)
+                Embed[vocab(word), :] = pred[0][1][0]
+
+        print("M4")
+        self.embed = nn.Embedding(len(vocab), embed_size)
+        self.embed.weight.data.copy_(torch.FloatTensor(Embed))
         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
-        self.linear = nn.Linear(hidden_size, vocab_size)
+        self.linear = nn.Linear(hidden_size, len(vocab))
         self.max_seg_length = max_seq_length
 
     def forward(self, features, captions, lengths):

@@ -47,10 +71,10 @@ def sample(self, features, states=None):
         inputs = features.unsqueeze(1)
         for i in range(self.max_seg_length):
             hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)
-            outputs = self.linear(hiddens.squeeze(1))    # outputs: (batch_size, vocab_size)
+            outputs = self.linear(hiddens.squeeze(1))    # outputs: (batch_size, len(vocab))
             _, predicted = outputs.max(1)                # predicted: (batch_size)
             sampled_ids.append(predicted)
             inputs = self.embed(predicted)               # inputs: (batch_size, embed_size)
             inputs = inputs.unsqueeze(1)                 # inputs: (batch_size, 1, embed_size)
         sampled_ids = torch.stack(sampled_ids, 1)        # sampled_ids: (batch_size, max_seq_length)
-        return sampled_ids
+        return sampled_ids
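
The rewritten constructor builds the embedding matrix from the pretrained word2vec-format BERT vectors, draws random vectors for the four special tokens, and falls back to running bert-embedding on any word missing from the file; since bert_embedding returns one (tokens, vectors) pair per input string, pred[0][1][0] is the 768-d vector of the word's first wordpiece. A minimal sketch of that lookup-then-fallback step in isolation (file name as assumed by the constructor; a membership test stands in for the bare except):

import gensim
from bert_embedding import BertEmbedding

lookup = gensim.models.KeyedVectors.load_word2vec_format(
    "bert-base-uncased.30522.768d.vec", binary=False)
bert = BertEmbedding()

def embedding_for(word):
    """Pretrained vector if available, else BERT's embedding of the
    word's first wordpiece (mirrors DecoderRNN.__init__)."""
    if word in lookup:      # KeyedVectors supports membership tests
        return lookup[word]
    pred = bert([word])     # [(tokens, vectors)]
    return pred[0][1][0]    # first token's 768-d vector

assert embedding_for('horse').shape == (768,)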

tutorials/03-advanced/image_captioning/resize.py

Lines changed: 1 addition & 1 deletion
@@ -39,4 +39,4 @@ def main(args):
     parser.add_argument('--image_size', type=int, default=256,
                         help='size for image after processing')
     args = parser.parse_args()
-    main(args)
+    main(args)

tutorials/03-advanced/image_captioning/sample.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def main(args):
 
     # Build models
     encoder = EncoderCNN(args.embed_size).eval()  # eval mode (batchnorm uses moving mean/variance)
-    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
+    decoder = DecoderRNN(args.embed_size, args.hidden_size, vocab, args.num_layers)
     encoder = encoder.to(device)
     decoder = decoder.to(device)
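
With the signature change, every caller of DecoderRNN now hands over the Vocabulary object rather than its length. A minimal sketch of building the decoder for sampling, assuming the vocabulary was pickled at the upstream tutorial's default path:

import pickle
from model import DecoderRNN

with open('data/vocab.pkl', 'rb') as f:  # default vocab path in the upstream tutorial
    vocab = pickle.load(f)

# Construction now also loads the BERT vectors and fills the embedding matrix.
decoder = DecoderRNN(768, 1536, vocab, num_layers=1).eval()
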
tutorials/03-advanced/image_captioning/train.py

Lines changed: 19 additions & 15 deletions
@@ -1,21 +1,22 @@
-import argparse
-import torch
-import torch.nn as nn
-import numpy as np
 import os
+import torch
 import pickle
+import argparse
+import numpy as np
+import torch.nn as nn
 from data_loader import get_loader
 from build_vocab import Vocabulary
+from torchvision import transforms
 from model import EncoderCNN, DecoderRNN
 from torch.nn.utils.rnn import pack_padded_sequence
-from torchvision import transforms
 import datetime
 
 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 now = datetime.datetime.now()
 dir = '/content/drive/My Drive/NLPResults/'
 
+
 def append_progress(line):
     filename = "Progress" + str(now.day) +'-'+ str(now.hour) + str(now.minute) + str(now.second) + ".txt"
     with open(dir + filename, 'a') as f:

@@ -47,18 +48,21 @@ def main(args):
     # Build data loader
     data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                              transform, args.batch_size,
-                             shuffle=True, num_workers=args.num_workers)
-
+                             shuffle=True, num_workers=args.num_workers)
+    print('Data loaded!\n')
+
     # Build the models
     encoder = EncoderCNN(args.embed_size).to(device)
-    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)
+    decoder = DecoderRNN(args.embed_size, args.hidden_size, vocab, args.num_layers).to(device)
 
     # Loss and optimizer
+    print('Encoder and decoder initialised!\n')
     criterion = nn.CrossEntropyLoss()
     params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
     optimizer = torch.optim.Adam(params, lr=args.learning_rate)
 
     # Train the models
+    print('Start training:\n')
     total_step = len(data_loader)
     for epoch in range(args.num_epochs):
         for i, (images, captions, lengths) in enumerate(data_loader):

@@ -80,20 +84,20 @@
             # Print log info
             if i % args.log_step == 0:
                 log_info = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item()))
+                print(log_info)
                 append_progress(log_info)
-                print(log_info)
 
             # Save the model checkpoints
             if (i+1) % args.save_step == 0:
                 torch.save(decoder.state_dict(), os.path.join(
                     args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                 torch.save(encoder.state_dict(), os.path.join(
                     args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
-                #Save to DRive
+                #Save to Drive
                 torch.save(decoder.state_dict(), os.path.join(
-                    dir, 'decoder-{}-{}-{}.ckpt'.format(epoch+1, i+1,now)))
+                    dir, 'decoder-{}-{}-{}.ckpt'.format(epoch+1, i+1, now)))
                 torch.save(encoder.state_dict(), os.path.join(
-                    dir, 'encoder-{}-{}-{}.ckpt'.format(epoch+1, i+1,now)))
+                    dir, 'encoder-{}-{}-{}.ckpt'.format(epoch+1, i+1, now)))
 
 
 if __name__ == '__main__':

@@ -107,8 +111,8 @@ def main(args):
     parser.add_argument('--save_step', type=int, default=1000, help='step size for saving trained models')
 
     # Model parameters
-    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
-    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
+    parser.add_argument('--embed_size', type=int, default=768, help='dimension of word embedding vectors')
+    parser.add_argument('--hidden_size', type=int, default=1536, help='dimension of lstm hidden states')
     parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
 
     parser.add_argument('--num_epochs', type=int, default=5)

@@ -117,4 +121,4 @@ def main(args):
     parser.add_argument('--learning_rate', type=float, default=0.001)
     args = parser.parse_args()
     print(args)
-    main(args)
+    main(args)
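
The new defaults tie the model to the pretrained vectors: --embed_size 768 matches the dimension of bert-base-uncased.30522.768d.vec, and --hidden_size 1536 is twice that. A quick check that the vector file agrees with the new embed_size default (a sketch, assuming the file sits in the working directory as in the setup script):

import gensim

lookup = gensim.models.KeyedVectors.load_word2vec_format(
    "bert-base-uncased.30522.768d.vec", binary=False)
assert lookup.vector_size == 768  # must equal args.embed_size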
