Skip to content

Commit 373dfbf

Browse files
committed
add inputs the classifer
1 parent fc762a7 commit 373dfbf

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

code/logistic_sgd.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444

4545
import theano
4646
import theano.tensor as T
47-
from theano.gof import graph
4847

4948

5049
class LogisticRegression(object):
@@ -110,6 +109,9 @@ def __init__(self, input, n_in, n_out):
110109
# parameters of the model
111110
self.params = [self.W, self.b]
112111

112+
# keep track of model input
113+
self.input = input
114+
113115
def negative_log_likelihood(self, y):
114116
"""Return the mean of the negative log-likelihood of the prediction
115117
of this model under a given target distribution.
@@ -447,17 +449,11 @@ def predict():
447449

448450
# load the saved model
449451
classifier = cPickle.load(open('best_model.pkl'))
450-
y_pred = classifier.y_pred
451-
452-
# find the input to theano graph
453-
inputs = graph.inputs([y_pred])
454-
# select only x
455-
inputs = [item for item in inputs if item.name == 'x']
456452

457453
# compile a predictor function
458454
predict_model = theano.function(
459-
inputs=inputs,
460-
outputs=y_pred)
455+
inputs=[classifier.input],
456+
outputs=classifier.y_pred)
461457

462458
# We can test it on some examples from test test
463459
dataset='mnist.pkl.gz'

doc/logreg.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,7 @@ labels of new data. ``predict`` function shows an example of how
274274
this could be done.
275275

276276
.. literalinclude:: ../code/logistic_sgd.py
277-
:start-after: ' ran for %.1fs' % ((end_time - start_time)))
278-
:end-before: if __name__ == '__main__':
277+
:pyobject: predict
279278

280279

281280
.. rubric:: Footnotes

0 commit comments

Comments
 (0)