- import os
- import numpy as np
+ from vocab import Vocabulary
+ from model import EncoderCNN, DecoderRNN
+ from configuration import Config
+ from PIL import Image
+ from torch.autograd import Variable
  import torch
- import torchvision.transforms as T
- import pickle
+ import torchvision.transforms as T
  import matplotlib.pyplot as plt
- from PIL import Image
- from model import EncoderCNN, DecoderRNN
- from vocab import Vocabulary
- from torch.autograd import Variable
-
- # Image processing
- transform = T.Compose([
-     T.CenterCrop(224),
-     T.ToTensor(),
-     T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-
- # Hyper Parameters
- embed_size = 128
- hidden_size = 512
- num_layers = 1
-
- # Load vocabulary
- with open('./data/vocab.pkl', 'rb') as f:
-     vocab = pickle.load(f)
-
- # Load an image array
- images = os.listdir('./data/train2014resized/')
- image_path = './data/train2014resized/' + images[12]
- img = Image.open(image_path)
- image = transform(img).unsqueeze(0)
+ import numpy as np
+ import argparse
+ import pickle
+ import os

- # Load the trained models
- encoder = torch.load('./encoder.pkl')
- decoder = torch.load('./decoder.pkl')

- # Encode the image
- feature = encoder(Variable(image).cuda())
+ def main(params):
+     # Configuration for hyper-parameters
+     config = Config()
+
+     # Image preprocessing
+     transform = config.test_transform
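+     # (test_transform is assumed to resize/center-crop the image and normalize
+     #  it with the same statistics used at training time)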

- # Set initial states
- state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
-          Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+     # Load vocabulary
+     with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
+         vocab = pickle.load(f)

- # Decode the feature to caption
- ids = decoder.sample(feature, state)
+     # Build Models
+     encoder = EncoderCNN(config.embed_size)
+     encoder.eval()  # evaluation mode (BatchNorm uses moving mean/variance)
+     decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                          len(vocab), config.num_layers)
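+     # len(vocab) presumably sets the size of the decoder's output layer
+     # (one logit per word in the vocabulary)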
+

- words = []
- for id in ids:
-     word = vocab.idx2word[id.data[0, 0]]
-     words.append(word)
-     if word == '<end>':
-         break
- caption = ' '.join(words)
+     # Load the trained model parameters
+     encoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                     config.trained_encoder)))
+     decoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                     config.trained_decoder)))
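+     # (config.trained_encoder / config.trained_decoder are assumed to name
+     #  state_dict files saved earlier with torch.save(model.state_dict(), path))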

- # Display the image and generated caption
- plt.imshow(img)
- plt.show()
- print(caption)
+     # Prepare the image
+     image = Image.open(params['image'])
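+     # unsqueeze(0) adds a batch dimension: (C, H, W) -> (1, C, H, W)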
+     image_tensor = Variable(transform(image).unsqueeze(0))
+
+     # Set initial states
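+     # h0 and c0 for the LSTM, each of shape (num_layers, batch_size=1, hidden_size)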
+     state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
+              Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
+
+     # If a GPU is available, move the models and tensors onto it
+     if torch.cuda.is_available():
+         encoder.cuda()
+         decoder.cuda()
+         state = tuple(s.cuda() for s in state)  # keep a tuple, as the LSTM expects
+         image_tensor = image_tensor.cuda()
+
+     # Generate caption from image
+     feature = encoder(image_tensor)
+     sampled_ids = decoder.sample(feature, state)
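+     # move the sampled ids to the CPU first; CUDA tensors cannot be converted to numpy directly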
+     sampled_ids = sampled_ids.cpu().data.numpy()
+
+     # Decode word_ids to words
+     sampled_caption = []
+     for word_id in sampled_ids:
+         word = vocab.idx2word[word_id]
+         sampled_caption.append(word)
+         if word == '<end>':
+             break
+     sentence = ' '.join(sampled_caption)
+
+     # Print out the image and the generated caption
+     print(sentence)
+     plt.imshow(np.asarray(image))
+     plt.show()  # without show(), no figure appears when run as a script
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image', type=str, required=True,
+                         help='image for generating caption')
+     args = parser.parse_args()
+     params = vars(args)
+     main(params)