Commit 55db5cd

Merge pull request lisa-lab#92 from memimo/predict
Predict function for logistic regression.
2 parents 15c5442 + efa74a1

5 files changed, +55 -0 lines changed

code/convolutional_mlp.py

Lines changed: 3 additions & 0 deletions
@@ -110,6 +110,9 @@ def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2, 2)):
         # store parameters of this layer
         self.params = [self.W, self.b]
 
+        # keep track of model input
+        self.input = input
+
 
 def evaluate_lenet5(learning_rate=0.1, n_epochs=200,
                     dataset='mnist.pkl.gz',
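
The same one-line change recurs in each model class touched by this commit: keeping a reference to the symbolic input variable is what later allows a predictor function to be compiled against the trained model. A minimal sketch of the idea (the names x, W, b and the MNIST-sized shapes are illustrative, not the tutorial's exact code):

# Sketch: why a model should keep its symbolic input around.
import numpy
import theano
import theano.tensor as T

x = T.matrix('x')  # the symbolic input the graph is built on (the role of self.input)
W = theano.shared(numpy.zeros((784, 10), dtype=theano.config.floatX), name='W')
b = theano.shared(numpy.zeros((10,), dtype=theano.config.floatX), name='b')
y_pred = T.argmax(T.nnet.softmax(T.dot(x, W) + b), axis=1)

# Because we still hold x, a predictor can be compiled at any later time,
# for example after unpickling the trained object:
predict_fn = theano.function(inputs=[x], outputs=y_pred)
print(predict_fn(numpy.zeros((2, 784), dtype=theano.config.floatX)))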

code/logistic_cg.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@ def __init__(self, input, n_in, n_out):
         # symbolic form
         self.y_pred = T.argmax(self.p_y_given_x, axis=1)
 
+        # keep track of model input
+        self.input = input
+
     def negative_log_likelihood(self, y):
         """Return the negative log-likelihood of the prediction of this model
         under a given target distribution.

code/logistic_sgd.py

Lines changed: 33 additions & 0 deletions
@@ -109,6 +109,9 @@ def __init__(self, input, n_in, n_out):
         # parameters of the model
         self.params = [self.W, self.b]
 
+        # keep track of model input
+        self.input = input
+
     def negative_log_likelihood(self, y):
         """Return the mean of the negative log-likelihood of the prediction
         of this model under a given target distribution.
@@ -415,6 +418,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
                         )
                     )
 
+                    # save the best model
+                    with open('best_model.pkl', 'w') as f:
+                        cPickle.dump(classifier, f)
+
             if patience <= iter:
                 done_looping = True
                 break
@@ -433,5 +440,31 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
           os.path.split(__file__)[1] +
           ' ran for %.1fs' % ((end_time - start_time)))
 
+
+def predict():
+    """
+    An example of how to load a trained model and use it
+    to predict labels.
+    """
+
+    # load the saved model
+    classifier = cPickle.load(open('best_model.pkl'))
+
+    # compile a predictor function
+    predict_model = theano.function(
+        inputs=[classifier.input],
+        outputs=classifier.y_pred)
+
+    # We can test it on some examples from the test set
+    dataset = 'mnist.pkl.gz'
+    datasets = load_data(dataset)
+    test_set_x, test_set_y = datasets[2]
+    test_set_x = test_set_x.get_value()
+
+    predicted_values = predict_model(test_set_x[:10])
+    print("Predicted values for the first 10 examples in test set:")
+    print predicted_values
+
+
 if __name__ == '__main__':
     sgd_optimization_mnist()
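
The save/load pair above is written in Python 2 style (cPickle and text-mode file handles). A hedged equivalent for Python 3, assuming the same best_model.pkl filename, would open the file in binary mode; this is a sketch, not part of the commit:

# Sketch only: the same save/load round trip on Python 3, where pickle
# files must be opened in binary mode.
import pickle

classifier = {'W': None, 'b': None}  # stand-in for the trained LogisticRegression object

with open('best_model.pkl', 'wb') as f:   # save the best model
    pickle.dump(classifier, f)

with open('best_model.pkl', 'rb') as f:   # reload it later, e.g. inside predict()
    classifier = pickle.load(f)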

code/mlp.py

Lines changed: 3 additions & 0 deletions
@@ -191,6 +191,9 @@ def __init__(self, rng, input, n_in, n_hidden, n_out):
         self.params = self.hiddenLayer.params + self.logRegressionLayer.params
         # end-snippet-3
 
+        # keep track of model input
+        self.input = input
+
 
 def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000,
              dataset='mnist.pkl.gz', batch_size=20, n_hidden=500):
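
The MLP (like the LeNet and conjugate-gradient models above) only gains the stored input here; the commit adds a predict helper for the SGD logistic regression alone. A hypothetical sketch of how the same pattern could be reused for a pickled MLP, assuming an instance was saved to a file such as best_mlp.pkl (a filename invented for this example) and that its output layer exposes y_pred as LogisticRegression does:

# Hypothetical, not in the commit: compile a predictor for a pickled MLP,
# relying on the self.input reference added above.
import cPickle
import theano

with open('best_mlp.pkl') as f:
    mlp = cPickle.load(f)

predict_mlp = theano.function(
    inputs=[mlp.input],
    outputs=mlp.logRegressionLayer.y_pred)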

doc/logreg.txt

Lines changed: 13 additions & 0 deletions
@@ -264,6 +264,19 @@ approximately 1.936 epochs/sec and it took 75 epochs to reach a test
 error of 7.489%. On the GPU the code does almost 10.0 epochs/sec. For this
 instance we used a batch size of 600.
 
+
+Prediction Using a Trained Model
+++++++++++++++++++++++++++++++++
+
+``sgd_optimization_mnist`` serializes and pickles the model each time a new
+lowest validation error is reached. We can reload this model and predict
+the labels of new data. The ``predict`` function shows an example of how
+this could be done.
+
+.. literalinclude:: ../code/logistic_sgd.py
+   :pyobject: predict
+
+
 .. rubric:: Footnotes
 
 .. [#f1] For smaller datasets and simpler models, more sophisticated descent
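
Taken together, the intended workflow is: run the training entry point once, which now pickles the classifier to best_model.pkl whenever a new best validation error is reached, then call predict() to reload the saved model and label new examples. A rough usage sketch, assuming code/logistic_sgd.py is importable and the MNIST data is reachable by load_data:

# Usage sketch (not part of the commit): train once, then predict from the saved model.
from logistic_sgd import sgd_optimization_mnist, predict

sgd_optimization_mnist()  # trains and writes best_model.pkl on each new best validation error
predict()                 # reloads best_model.pkl and prints labels for the first 10 test examples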
