| 
1 | 1 | import torch  | 
2 | 2 | import torch.nn as nn  | 
3 | 3 | import torchvision.models as models  | 
4 |  | -import torch.nn.utils.rnn as rnn_utils  | 
 | 4 | +from torch.nn.utils.rnn import pack_padded_sequence  | 
5 | 5 | from torch.autograd import Variable  | 
6 | 6 | 
 
  | 
7 | 7 | 
 
  | 
@@ -31,27 +31,22 @@ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):  | 
31 | 31 |         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)  | 
32 | 32 |         self.linear = nn.Linear(hidden_size, vocab_size)  | 
33 | 33 | 
 
  | 
34 |  | -    def init_weights(self):  | 
35 |  | -        pass  | 
36 |  | -          | 
37 | 34 |     def forward(self, features, captions, lengths):  | 
38 | 35 |         """Decode image feature vectors and generate caption."""  | 
39 | 36 |         embeddings = self.embed(captions)  | 
40 | 37 |         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)  | 
41 |  | -        packed = rnn_utils.pack_padded_sequence(embeddings, lengths, batch_first=True) # lengths is ok  | 
 | 38 | +        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)   | 
42 | 39 |         hiddens, _ = self.lstm(packed)  | 
43 | 40 |         outputs = self.linear(hiddens[0])  | 
44 | 41 |         return outputs  | 
45 | 42 | 
 
  | 
46 | 43 |     def sample(self, feature, state):  | 
47 | 44 |         """Sample a caption for given a image feature."""  | 
48 |  | -        # (batch_size, seq_length, embed_size)  | 
49 |  | -        # features: (1, 128)  | 
50 | 45 |         sampled_ids = []  | 
51 | 46 |         input = feature.unsqueeze(1)  | 
52 | 47 |         for i in range(20):  | 
53 |  | -            hidden, state = self.lstm(input, state)  # (1, 1, 512)  | 
54 |  | -            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, 10000)  | 
 | 48 | +            hidden, state = self.lstm(input, state)                  # (1, 1, hidden_size)  | 
 | 49 | +            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, vocab_size)  | 
55 | 50 |             predicted = output.max(1)[1]  | 
56 | 51 |             sampled_ids.append(predicted)  | 
57 | 52 |             input = self.embed(predicted)  | 
 | 
0 commit comments