Skip to content

Commit 85d2e3e

Browse files
committed
image caption generation added
1 parent 77f4aa8 commit 85d2e3e

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
import numpy as np
3+
import torch
4+
import pickle
5+
import matplotlib.pyplot as plt
6+
from PIL import Image
7+
from model import EncoderCNN, DecoderRNN
8+
from vocab import Vocabulary
9+
from torch.autograd import Variable
10+
11+
# Hyper Parameters
12+
embed_size = 128
13+
hidden_size = 512
14+
num_layers = 1
15+
16+
# Load vocabulary
17+
with open('./data/vocab.pkl', 'rb') as f:
18+
vocab = pickle.load(f)
19+
20+
# Load an image array
21+
images = os.listdir('./data/val2014resized/')
22+
image_path = './data/val2014resized/' + images[12]
23+
with open(image_path, 'r+b') as f:
24+
img = np.asarray(Image.open(f))
25+
image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5
26+
27+
# Load the trained models
28+
encoder = torch.load('./encoder.pkl')
29+
decoder = torch.load('./decoder.pkl')
30+
31+
# Encode the image
32+
feature = encoder(Variable(image).cuda())
33+
34+
# Set initial states
35+
state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
36+
Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
37+
38+
# Decode the feature to caption
39+
ids = decoder.sample(feature, state)
40+
41+
words = []
42+
for id in ids:
43+
word = vocab.idx2word[id.data[0, 0]]
44+
words.append(word)
45+
if word == '<end>':
46+
break
47+
caption = ' '.join(words)
48+
49+
# Display the image and generated caption
50+
plt.imshow(img)
51+
plt.show()
52+
print (caption)

0 commit comments

Comments
 (0)