import os
import numpy as np
import torch
import pickle
import matplotlib.pyplot as plt
from PIL import Image
from model import EncoderCNN, DecoderRNN
from vocab import Vocabulary
from torch.autograd import Variable

# Hyperparameters
embed_size = 128
hidden_size = 512
num_layers = 1
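# NOTE: hidden_size and num_layers must match the configuration the decoder
# checkpoint was trained with, since they fix the shape of the initial LSTM
# state below; embed_size is unused in this script and kept for reference.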

# Load vocabulary
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
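# Unpickling needs the Vocabulary class to be importable (hence the import
# above); vocab.idx2word maps predicted word indices back to token strings.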

# Load an image array; convert('RGB') guards against grayscale images,
# whose 2-D arrays would break the HWC -> CHW transpose below
images = os.listdir('./data/val2014resized/')
image_path = './data/val2014resized/' + images[12]
with open(image_path, 'rb') as f:
    img = np.asarray(Image.open(f).convert('RGB'))
# HxWxC uint8 -> 1x3xHxW float, scaled to roughly [-0.5, 0.5]
image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5
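# NOTE: this (x / 255 - 0.5) scaling is only meaningful if it mirrors the
# preprocessing applied when the encoder was trained; treat that as an
# assumption of this script rather than a documented guarantee.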

# Load the trained models (the checkpoints appear to hold whole model objects,
# so EncoderCNN / DecoderRNN must be importable for unpickling to succeed)
encoder = torch.load('./encoder.pkl')
decoder = torch.load('./decoder.pkl')
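# A standard inference-time safeguard: switch both networks to eval mode so
# any dropout or batch-norm layers behave deterministically while sampling.
encoder.eval()
decoder.eval()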

# Encode the image (the script assumes a CUDA-capable GPU throughout)
feature = encoder(Variable(image).cuda())

# Set the initial LSTM state: an (h0, c0) pair, each of shape
# (num_layers, batch_size=1, hidden_size)
state = (Variable(torch.zeros(num_layers, 1, hidden_size)).cuda(),
         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())

# Decode the feature into a caption
ids = decoder.sample(feature, state)
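# decoder.sample is defined in model.py; from the loop below it evidently
# yields a sequence of 1x1 index Variables, presumably produced by greedy
# (argmax) decoding.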

# Map the predicted indices back to words, stopping at the end token
words = []
for word_id in ids:
    word = vocab.idx2word[word_id.data[0, 0]]
    words.append(word)
    if word == '<end>':
        break
caption = ' '.join(words)

# Print the generated caption before displaying the image, since plt.show()
# blocks until the figure window is closed
print(caption)
plt.imshow(img)
plt.show()