@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision.models as models
-import torch.nn.utils.rnn as rnn_utils
+from torch.nn.utils.rnn import pack_padded_sequence
 from torch.autograd import Variable
 
 
@@ -31,27 +31,22 @@ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
         self.linear = nn.Linear(hidden_size, vocab_size)
 
-    def init_weights(self):
-        pass
-
     def forward(self, features, captions, lengths):
         """Decode image feature vectors and generate a caption."""
         embeddings = self.embed(captions)
         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
-        packed = rnn_utils.pack_padded_sequence(embeddings, lengths, batch_first=True)  # lengths is ok
+        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
         hiddens, _ = self.lstm(packed)
         outputs = self.linear(hiddens[0])
         return outputs
 
     def sample(self, feature, state):
         """Sample a caption for a given image feature."""
-        # (batch_size, seq_length, embed_size)
-        # features: (1, 128)
         sampled_ids = []
         input = feature.unsqueeze(1)
         for i in range(20):
-            hidden, state = self.lstm(input, state)  # (1, 1, 512)
-            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, 10000)
+            hidden, state = self.lstm(input, state)  # (1, 1, hidden_size)
+            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, vocab_size)
             predicted = output.max(1)[1]
             sampled_ids.append(predicted)
             input = self.embed(predicted)