import os
import pickle

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch.autograd import Variable

# EncoderCNN/DecoderRNN and Vocabulary must be importable so that
# torch.load and pickle.load below can reconstruct the saved objects.
from model import EncoderCNN, DecoderRNN
from vocab import Vocabulary

# Hyper Parameters (must match the values used when the models were trained)
embed_size = 128
hidden_size = 512
num_layers = 1

# Load vocabulary
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
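# Optional sanity check (assumes idx2word is a plain dict, which is how
# it is indexed further down when converting sampled ids to words):
print('Vocabulary size:', len(vocab.idx2word))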

# Load an image array
images = os.listdir('./data/val2014resized/')
image_path = './data/val2014resized/' + images[12]
with open(image_path, 'rb') as f:
    img = np.asarray(Image.open(f))
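# Guard (assumption): the encoder expects a fixed input size; 224x224 is
# the resolution commonly used by image-captioning tutorials like this
# one, not something this script states. Resize only if an image differs.
if img.shape[:2] != (224, 224):  # assumed training resolution
    img = np.asarray(Image.fromarray(img).resize((224, 224)))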
# Convert the HWC uint8 array to a 1xCxHxW float tensor in [-0.5, 0.5]
image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5

# Load the trained models (torch.load unpickles the full module objects,
# so the EncoderCNN and DecoderRNN classes imported above are required)
encoder = torch.load('./encoder.pkl')
decoder = torch.load('./decoder.pkl')
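# Not in the original script, but a common precaution: switch both
# modules to evaluation mode so layers such as dropout or batch norm
# (if present in EncoderCNN/DecoderRNN) behave deterministically.
encoder.eval()
decoder.eval()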

# Encode the image
feature = encoder(Variable(image).cuda())

# Set the initial states (h0, c0), both on the GPU
state = (Variable(torch.zeros(num_layers, 1, hidden_size)).cuda(),
         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())

# Decode the feature into a sequence of word ids; each element is
# indexed as a 1x1 Variable below
ids = decoder.sample(feature, state)

# Map word ids back to words, stopping at the end-of-sentence token
words = []
for word_id in ids:  # renamed from `id` to avoid shadowing the builtin
    word = vocab.idx2word[word_id.data[0, 0]]
    words.append(word)
    if word == '<end>':
        break
caption = ' '.join(words)

# Display the image and the generated caption
plt.imshow(img)
plt.show()
print(caption)