44
44
45
45
import theano
46
46
import theano .tensor as T
47
+ from theano .gof import graph
47
48
48
49
49
50
class LogisticRegression (object ):
@@ -415,6 +416,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
415
416
)
416
417
)
417
418
419
+ # save the best model
420
+ with open ('best_model.pkl' , 'w' ) as f :
421
+ cPickle .dump (classifier , f )
422
+
418
423
if patience <= iter :
419
424
done_looping = True
420
425
break
@@ -433,5 +438,37 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
433
438
os .path .split (__file__ )[1 ] +
434
439
' ran for %.1fs' % ((end_time - start_time )))
435
440
441
+
442
+ def predict ():
443
+ """
444
+ An example of how to load a train model and use it
445
+ to predict labels.
446
+ """
447
+
448
+ # load the saved model
449
+ classifier = cPickle .load (open ('best_model.pkl' ))
450
+ y_pred = classifier .y_pred
451
+
452
+ # find the input to theano graph
453
+ inputs = graph .inputs ([y_pred ])
454
+ # select only x
455
+ inputs = [item for item in inputs if item .name == 'x' ]
456
+
457
+ # compile a predictor function
458
+ predict_model = theano .function (
459
+ inputs = inputs ,
460
+ outputs = y_pred )
461
+
462
+ # We can test it on some examples from test test
463
+ dataset = 'mnist.pkl.gz'
464
+ datasets = load_data (dataset )
465
+ test_set_x , test_set_y = datasets [2 ]
466
+ test_set_x = test_set_x .get_value ()
467
+
468
+ predicted_values = predict_model (test_set_x [:10 ])
469
+ print ("Predicted values for the first 10 examples in test set:" )
470
+ print predicted_values
471
+
472
+
436
473
if __name__ == '__main__' :
437
474
sgd_optimization_mnist ()
0 commit comments