1- from data import get_loader
1+ from data import get_data_loader
22from vocab import Vocabulary
3+ from configuration import Config
34from model import EncoderCNN , DecoderRNN
45from torch .autograd import Variable
56from torch .nn .utils .rnn import pack_padded_sequence
67import torch
7- import torch .nn as nn
8- import numpy as np
8+ import torch .nn as nn
99import torchvision .transforms as T
10+ import numpy as np
1011import pickle
12+ import os
1113
12- # Hyper Parameters
13- num_epochs = 1
14- batch_size = 32
15- embed_size = 256
16- hidden_size = 512
17- crop_size = 224
18- num_layers = 1
19- learning_rate = 0.001
20- train_image_path = './data/train2014resized/'
21- train_json_path = './data/annotations/captions_train2014.json'
2214
23- # Image Preprocessing
24- transform = T .Compose ([
25- T .RandomCrop (crop_size ),
26- T .RandomHorizontalFlip (),
27- T .ToTensor (),
28- T .Normalize (mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ))])
15+ def main ():
16+ # Configuration for hyper-parameters
17+ config = Config ()
18+
19+ # Image preprocessing
20+ transform = T .Compose ([
21+ T .Scale (config .image_size ), # no resize
22+ T .RandomCrop (config .crop_size ),
23+ T .RandomHorizontalFlip (),
24+ T .ToTensor (),
25+ T .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))])
2926
30- # Load Vocabulary Wrapper
31- with open ('./data/ vocab.pkl' , 'rb' ) as f :
27+ # Load vocabulary wrapper
28+ with open (os . path . join ( config . vocab_path , ' vocab.pkl') , 'rb' ) as f :
3229 vocab = pickle .load (f )
33-
34- # Build Dataset Loader
35- train_loader = get_loader (train_image_path , train_json_path , vocab , transform ,
36- batch_size = batch_size , shuffle = True , num_workers = 2 )
37- total_step = len (train_loader )
3830
39- # Build Models
40- encoder = EncoderCNN (embed_size )
41- decoder = DecoderRNN (embed_size , hidden_size , len (vocab ), num_layers )
42- encoder .cuda ()
43- decoder .cuda ()
44-
45- # Loss and Optimizer
46- criterion = nn .CrossEntropyLoss ()
47- params = list (decoder .parameters ()) + list (encoder .resnet .fc .parameters ())
48- optimizer = torch .optim .Adam (params , lr = learning_rate )
31+ # Build data loader
32+ image_path = os .path .join (config .image_path , 'train2014' )
33+ json_path = os .path .join (config .caption_path , 'captions_train2014.json' )
34+ train_loader = get_data_loader (image_path , json_path , vocab ,
35+ transform , config .batch_size ,
36+ shuffle = True , num_workers = config .num_threads )
37+ total_step = len (train_loader )
38+
39+ # Build Models
40+ encoder = EncoderCNN (config .embed_size )
41+ decoder = DecoderRNN (config .embed_size , config .hidden_size ,
42+ len (vocab ), config .num_layers )
43+ encoder .cuda ()
44+ decoder .cuda ()
45+
46+ # Loss and Optimizer
47+ criterion = nn .CrossEntropyLoss ()
48+ params = list (decoder .parameters ()) + list (encoder .resnet .fc .parameters ())
49+ optimizer = torch .optim .Adam (params , lr = config .learning_rate )
50+
51+ # Train the Models
52+ for epoch in range (config .num_epochs ):
53+ for i , (images , captions , lengths ) in enumerate (train_loader ):
54+
55+ # Set mini-batch dataset
56+ images = Variable (images ).cuda ()
57+ captions = Variable (captions ).cuda ()
58+ targets = pack_padded_sequence (captions , lengths , batch_first = True )[0 ]
4959
50- # Train the Decoder
51- for epoch in range (num_epochs ):
52- for i , (images , captions , lengths ) in enumerate (train_loader ):
53- # Set mini-batch dataset
54- images = Variable (images ).cuda ()
55- captions = Variable (captions ).cuda ()
56- targets = pack_padded_sequence (captions , lengths , batch_first = True )[0 ]
57-
58- # Forward, Backward and Optimize
59- decoder .zero_grad ()
60- features = encoder (images )
61- outputs = decoder (features , captions , lengths )
62- loss = criterion (outputs , targets )
63- loss .backward ()
64- optimizer .step ()
65-
66- if i % 100 == 0 :
67- print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
68- % (epoch , num_epochs , i , total_step , loss .data [0 ], np .exp (loss .data [0 ])))
60+ # Forward, Backward and Optimize
61+ decoder .zero_grad ()
62+ encoder .zero_grad ()
63+ features = encoder (images )
64+ outputs = decoder (features , captions , lengths )
65+ loss = criterion (outputs , targets )
66+ loss .backward ()
67+ optimizer .step ()
6968
70- # Save the Model
71- torch .save (decoder , 'decoder.pkl' )
72- torch .save (encoder , 'encoder.pkl' )
69+ # Print log info
70+ if i % config .log_step == 0 :
71+ print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
72+ % (epoch , config .num_epochs , i , total_step ,
73+ loss .data [0 ], np .exp (loss .data [0 ])))
74+
75+ # Save the Model
76+ if (i + 1 ) % config .save_step == 0 :
77+ torch .save (decoder .state_dict (),
78+ os .path .join (config .model_path ,
79+ 'decoder-%d-%d.pkl' % (epoch + 1 , i + 1 )))
80+ torch .save (encoder .state_dict (),
81+ os .path .join (config .model_path ,
82+ 'encoder-%d-%d.pkl' % (epoch + 1 , i + 1 )))
83+ if __name__ == '__main__' :
84+ main ()
0 commit comments