@@ -1,21 +1,22 @@
-import argparse
-import torch
-import torch.nn as nn
-import numpy as np
 import os
+import torch
 import pickle
+import argparse
+import numpy as np
+import torch.nn as nn
 from data_loader import get_loader
 from build_vocab import Vocabulary
+from torchvision import transforms
 from model import EncoderCNN, DecoderRNN
 from torch.nn.utils.rnn import pack_padded_sequence
-from torchvision import transforms
 import datetime

 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 now = datetime.datetime.now()
 dir = '/content/drive/My Drive/NLPResults/'

+
 def append_progress(line):
     filename = "Progress" + str(now.day) + '-' + str(now.hour) + str(now.minute) + str(now.second) + ".txt"
     with open(dir + filename, 'a') as f:
@@ -47,18 +48,21 @@ def main(args):
     # Build data loader
     data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                              transform, args.batch_size,
-                             shuffle=True, num_workers=args.num_workers)
-
+                             shuffle=True, num_workers=args.num_workers)
+    print('Data loaded!\n')
+
     # Build the models
     encoder = EncoderCNN(args.embed_size).to(device)
-    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)
+    decoder = DecoderRNN(args.embed_size, args.hidden_size, vocab, args.num_layers).to(device)

     # Loss and optimizer
+    print('Encoder and decoder initialised!\n')
     criterion = nn.CrossEntropyLoss()
     params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
     optimizer = torch.optim.Adam(params, lr=args.learning_rate)

     # Train the models
+    print('Start training:\n')
     total_step = len(data_loader)
     for epoch in range(args.num_epochs):
         for i, (images, captions, lengths) in enumerate(data_loader):
@@ -80,20 +84,20 @@ def main(args):
             # Print log info
             if i % args.log_step == 0:
                 log_info = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item()))
+                print(log_info)
                 append_progress(log_info)
-                print(log_info)

             # Save the model checkpoints
             if (i + 1) % args.save_step == 0:
                 torch.save(decoder.state_dict(), os.path.join(
                     args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                 torch.save(encoder.state_dict(), os.path.join(
                     args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
-                #Save to DRive
+                #Save to Drive
                 torch.save(decoder.state_dict(), os.path.join(
-                    dir, 'decoder-{}-{}-{}.ckpt'.format(epoch + 1, i + 1,now)))
+                    dir, 'decoder-{}-{}-{}.ckpt'.format(epoch + 1, i + 1, now)))
                 torch.save(encoder.state_dict(), os.path.join(
-                    dir, 'encoder-{}-{}-{}.ckpt'.format(epoch + 1, i + 1,now)))
+                    dir, 'encoder-{}-{}-{}.ckpt'.format(epoch + 1, i + 1, now)))


 if __name__ == '__main__':
@@ -107,8 +111,8 @@ def main(args):
     parser.add_argument('--save_step', type=int, default=1000, help='step size for saving trained models')

     # Model parameters
-    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
-    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
+    parser.add_argument('--embed_size', type=int, default=768, help='dimension of word embedding vectors')
+    parser.add_argument('--hidden_size', type=int, default=1536, help='dimension of lstm hidden states')
     parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')

     parser.add_argument('--num_epochs', type=int, default=5)
@@ -117,4 +121,4 @@ def main(args):
     parser.add_argument('--learning_rate', type=float, default=0.001)
     args = parser.parse_args()
     print(args)
-    main(args)
+    main(args)
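
For reference, the training-step body elided between the second and third hunks follows the standard encoder-decoder pattern; a minimal sketch, assuming the EncoderCNN/DecoderRNN interfaces used elsewhere in this script (CNN features seed the LSTM decoder, and the ground-truth captions are packed so they line up with the decoder's packed output):

    # Sketch of the elided forward/backward step (assumed; not shown in this diff)
    images = images.to(device)
    captions = captions.to(device)
    # Pack targets to match the decoder's packed output over true caption lengths
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

    # Forward, backward and optimize
    features = encoder(images)
    outputs = decoder(features, captions, lengths)
    loss = criterion(outputs, targets)
    decoder.zero_grad()
    encoder.zero_grad()
    loss.backward()
    optimizer.step()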