Skip to content

Commit 8f002a5

Browse files
committed
add predict option to logistic regression
mend mend
1 parent b0dd8f0 commit 8f002a5

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

code/logistic_sgd.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import theano
4646
import theano.tensor as T
47+
from theano.gof import graph
4748

4849

4950
class LogisticRegression(object):
@@ -415,6 +416,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
415416
)
416417
)
417418

419+
# save the best model
420+
with open('best_model.pkl', 'w') as f:
421+
cPickle.dump(classifier, f)
422+
418423
if patience <= iter:
419424
done_looping = True
420425
break
@@ -433,5 +438,37 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
433438
os.path.split(__file__)[1] +
434439
' ran for %.1fs' % ((end_time - start_time)))
435440

441+
442+
def predict():
443+
"""
444+
An example of how to load a train model and use it
445+
to predict labels.
446+
"""
447+
448+
# load the saved model
449+
classifier = cPickle.load(open('best_model.pkl'))
450+
y_pred = classifier.y_pred
451+
452+
# find the input to theano graph
453+
inputs = graph.inputs([y_pred])
454+
# select only x
455+
inputs = [item for item in inputs if item.name == 'x']
456+
457+
# compile a predictor function
458+
predict_model = theano.function(
459+
inputs=inputs,
460+
outputs=y_pred)
461+
462+
# We can test it on some examples from test test
463+
dataset='mnist.pkl.gz'
464+
datasets = load_data(dataset)
465+
test_set_x, test_set_y = datasets[2]
466+
test_set_x = test_set_x.get_value()
467+
468+
predicted_values = predict_model(test_set_x[:10])
469+
print ("Predicted values for the first 10 examples in test set:")
470+
print predicted_values
471+
472+
436473
if __name__ == '__main__':
437474
sgd_optimization_mnist()

0 commit comments

Comments
 (0)