Skip to content

Commit f159685

Browse files
committed
model edited
1 parent 85d2e3e commit f159685

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

tutorials/09 - Image Captioning/model.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
import torchvision.models as models
4-
import torch.nn.utils.rnn as rnn_utils
4+
from torch.nn.utils.rnn import pack_padded_sequence
55
from torch.autograd import Variable
66

77

@@ -31,27 +31,22 @@ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
3131
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
3232
self.linear = nn.Linear(hidden_size, vocab_size)
3333

34-
def init_weights(self):
35-
pass
36-
3734
def forward(self, features, captions, lengths):
    """Decode image feature vectors and generate word scores for a caption.

    Args:
        features: image feature tensor of shape (batch, embed_size)
            (shape implied by the unsqueeze(1)/cat below — TODO confirm
            against the encoder).
        captions: LongTensor of word indices, shape (batch, max_caption_len).
        lengths: valid sequence lengths (image feature + caption words),
            sorted in descending order as required by pack_padded_sequence.

    Returns:
        Vocabulary scores for every valid packed time step,
        shape (sum(lengths), vocab_size).
    """
    embeddings = self.embed(captions)
    # Prepend the image feature as the first "word" of each sequence.
    embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
    # Pack so the LSTM skips padded positions entirely.
    packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
    hiddens, _ = self.lstm(packed)
    # hiddens[0] is the packed data tensor of hidden states; project each
    # valid time step to vocabulary scores.
    outputs = self.linear(hiddens[0])
    return outputs
4542

4643
def sample(self, feature, state):
4744
"""Sample a caption for given a image feature."""
48-
# (batch_size, seq_length, embed_size)
49-
# features: (1, 128)
5045
sampled_ids = []
5146
input = feature.unsqueeze(1)
5247
for i in range(20):
53-
hidden, state = self.lstm(input, state) # (1, 1, 512)
54-
output = self.linear(hidden.view(-1, self.hidden_size)) # (1, 10000)
48+
hidden, state = self.lstm(input, state) # (1, 1, hidden_size)
49+
output = self.linear(hidden.view(-1, self.hidden_size)) # (1, vocab_size)
5550
predicted = output.max(1)[1]
5651
sampled_ids.append(predicted)
5752
input = self.embed(predicted)

0 commit comments

Comments
 (0)