Skip to content

Commit d2764f2

Browse files
committed
successfully ported logistic_sgd.py
1 parent 50b5010 commit d2764f2

File tree

1 file changed

+32
-27
lines changed

1 file changed

+32
-27
lines changed

code/logistic_sgd.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@
3232
Christopher M. Bishop, section 4.3.2
3333
3434
"""
35+
36+
from __future__ import print_function
37+
3538
__docformat__ = 'restructedtext en'
3639

37-
import cPickle
40+
import six.moves.cPickle as pickle
3841
import gzip
3942
import os
4043
import sys
@@ -194,19 +197,21 @@ def load_data(dataset):
194197
dataset = new_path
195198

196199
if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
197-
import urllib
200+
from six.moves import urllib
198201
origin = (
199202
'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
200203
)
201-
print 'Downloading data from %s' % origin
202-
urllib.urlretrieve(origin, dataset)
204+
print('Downloading data from %s' % origin)
205+
urllib.request.urlretrieve(origin, dataset)
203206

204-
print '... loading data'
207+
print('... loading data')
205208

206209
# Load the dataset
207-
f = gzip.open(dataset, 'rb')
208-
train_set, valid_set, test_set = cPickle.load(f)
209-
f.close()
210+
with gzip.open(dataset, 'rb') as f:
211+
try:
212+
train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
213+
except:
214+
train_set, valid_set, test_set = pickle.load(f)
210215
# train_set, valid_set, test_set format: tuple(input, target)
211216
# input is a numpy.ndarray of 2 dimensions (a matrix)
212217
# where each row corresponds to an example. target is a
@@ -276,14 +281,14 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
276281
test_set_x, test_set_y = datasets[2]
277282

278283
# compute number of minibatches for training, validation and testing
279-
n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
280-
n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size
281-
n_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size
284+
n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
285+
n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] // batch_size
286+
n_test_batches = test_set_x.get_value(borrow=True).shape[0] // batch_size
282287

283288
######################
284289
# BUILD ACTUAL MODEL #
285290
######################
286-
print '... building the model'
291+
print('... building the model')
287292

288293
# allocate symbolic variables for the data
289294
index = T.lscalar() # index to a [mini]batch
@@ -348,14 +353,14 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
348353
###############
349354
# TRAIN MODEL #
350355
###############
351-
print '... training the model'
356+
print('... training the model')
352357
# early-stopping parameters
353358
patience = 5000 # look as this many examples regardless
354359
patience_increase = 2 # wait this much longer when a new best is
355360
# found
356361
improvement_threshold = 0.995 # a relative improvement of this much is
357362
# considered significant
358-
validation_frequency = min(n_train_batches, patience / 2)
363+
validation_frequency = min(n_train_batches, patience // 2)
359364
# go through this many
360365
# minibatche before checking the network
361366
# on the validation set; in this case we
@@ -369,7 +374,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
369374
epoch = 0
370375
while (epoch < n_epochs) and (not done_looping):
371376
epoch = epoch + 1
372-
for minibatch_index in xrange(n_train_batches):
377+
for minibatch_index in range(n_train_batches):
373378

374379
minibatch_avg_cost = train_model(minibatch_index)
375380
# iteration number
@@ -378,7 +383,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
378383
if (iter + 1) % validation_frequency == 0:
379384
# compute zero-one loss on validation set
380385
validation_losses = [validate_model(i)
381-
for i in xrange(n_valid_batches)]
386+
for i in range(n_valid_batches)]
382387
this_validation_loss = numpy.mean(validation_losses)
383388

384389
print(
@@ -402,7 +407,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
402407
# test it on the test set
403408

404409
test_losses = [test_model(i)
405-
for i in xrange(n_test_batches)]
410+
for i in range(n_test_batches)]
406411
test_score = numpy.mean(test_losses)
407412

408413
print(
@@ -419,8 +424,8 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
419424
)
420425

421426
# save the best model
422-
with open('best_model.pkl', 'w') as f:
423-
cPickle.dump(classifier, f)
427+
with open('best_model.pkl', 'wb') as f:
428+
pickle.dump(classifier, f)
424429

425430
if patience <= iter:
426431
done_looping = True
@@ -434,11 +439,11 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
434439
)
435440
% (best_validation_loss * 100., test_score * 100.)
436441
)
437-
print 'The code run for %d epochs, with %f epochs/sec' % (
438-
epoch, 1. * epoch / (end_time - start_time))
439-
print >> sys.stderr, ('The code for file ' +
440-
os.path.split(__file__)[1] +
441-
' ran for %.1fs' % ((end_time - start_time)))
442+
print('The code run for %d epochs, with %f epochs/sec' % (
443+
epoch, 1. * epoch / (end_time - start_time)))
444+
print(('The code for file ' +
445+
os.path.split(__file__)[1] +
446+
' ran for %.1fs' % ((end_time - start_time))), file=sys.stderr)
442447

443448

444449
def predict():
@@ -448,7 +453,7 @@ def predict():
448453
"""
449454

450455
# load the saved model
451-
classifier = cPickle.load(open('best_model.pkl'))
456+
classifier = pickle.load(open('best_model.pkl'))
452457

453458
# compile a predictor function
454459
predict_model = theano.function(
@@ -462,8 +467,8 @@ def predict():
462467
test_set_x = test_set_x.get_value()
463468

464469
predicted_values = predict_model(test_set_x[:10])
465-
print ("Predicted values for the first 10 examples in test set:")
466-
print predicted_values
470+
print("Predicted values for the first 10 examples in test set:")
471+
print(predicted_values)
467472

468473

469474
if __name__ == '__main__':

0 commit comments

Comments (0)