32
32
Christopher M. Bishop, section 4.3.2
33
33
34
34
"""
35
+
36
+ from __future__ import print_function
37
+
35
38
__docformat__ = 'restructedtext en'
36
39
37
- import cPickle
40
+ import six . moves . cPickle as pickle
38
41
import gzip
39
42
import os
40
43
import sys
@@ -194,19 +197,21 @@ def load_data(dataset):
194
197
dataset = new_path
195
198
196
199
if (not os .path .isfile (dataset )) and data_file == 'mnist.pkl.gz' :
197
- import urllib
200
+ from six . moves import urllib
198
201
origin = (
199
202
'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
200
203
)
201
- print 'Downloading data from %s' % origin
202
- urllib .urlretrieve (origin , dataset )
204
+ print ( 'Downloading data from %s' % origin )
205
+ urllib .request . urlretrieve (origin , dataset )
203
206
204
- print '... loading data'
207
+ print ( '... loading data' )
205
208
206
209
# Load the dataset
207
- f = gzip .open (dataset , 'rb' )
208
- train_set , valid_set , test_set = cPickle .load (f )
209
- f .close ()
210
+ with gzip .open (dataset , 'rb' ) as f :
211
+ try :
212
+ train_set , valid_set , test_set = pickle .load (f , encoding = 'latin1' )
213
+ except :
214
+ train_set , valid_set , test_set = pickle .load (f )
210
215
# train_set, valid_set, test_set format: tuple(input, target)
211
216
# input is a numpy.ndarray of 2 dimensions (a matrix)
212
217
# where each row corresponds to an example. target is a
@@ -276,14 +281,14 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
276
281
test_set_x , test_set_y = datasets [2 ]
277
282
278
283
# compute number of minibatches for training, validation and testing
279
- n_train_batches = train_set_x .get_value (borrow = True ).shape [0 ] / batch_size
280
- n_valid_batches = valid_set_x .get_value (borrow = True ).shape [0 ] / batch_size
281
- n_test_batches = test_set_x .get_value (borrow = True ).shape [0 ] / batch_size
284
+ n_train_batches = train_set_x .get_value (borrow = True ).shape [0 ] // batch_size
285
+ n_valid_batches = valid_set_x .get_value (borrow = True ).shape [0 ] // batch_size
286
+ n_test_batches = test_set_x .get_value (borrow = True ).shape [0 ] // batch_size
282
287
283
288
######################
284
289
# BUILD ACTUAL MODEL #
285
290
######################
286
- print '... building the model'
291
+ print ( '... building the model' )
287
292
288
293
# allocate symbolic variables for the data
289
294
index = T .lscalar () # index to a [mini]batch
@@ -348,14 +353,14 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
348
353
###############
349
354
# TRAIN MODEL #
350
355
###############
351
- print '... training the model'
356
+ print ( '... training the model' )
352
357
# early-stopping parameters
353
358
patience = 5000 # look at this many examples regardless
354
359
patience_increase = 2 # wait this much longer when a new best is
355
360
# found
356
361
improvement_threshold = 0.995 # a relative improvement of this much is
357
362
# considered significant
358
- validation_frequency = min (n_train_batches , patience / 2 )
363
+ validation_frequency = min (n_train_batches , patience // 2 )
359
364
# go through this many
360
365
# minibatches before checking the network
361
366
# on the validation set; in this case we
@@ -369,7 +374,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
369
374
epoch = 0
370
375
while (epoch < n_epochs ) and (not done_looping ):
371
376
epoch = epoch + 1
372
- for minibatch_index in xrange (n_train_batches ):
377
+ for minibatch_index in range (n_train_batches ):
373
378
374
379
minibatch_avg_cost = train_model (minibatch_index )
375
380
# iteration number
@@ -378,7 +383,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
378
383
if (iter + 1 ) % validation_frequency == 0 :
379
384
# compute zero-one loss on validation set
380
385
validation_losses = [validate_model (i )
381
- for i in xrange (n_valid_batches )]
386
+ for i in range (n_valid_batches )]
382
387
this_validation_loss = numpy .mean (validation_losses )
383
388
384
389
print (
@@ -402,7 +407,7 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
402
407
# test it on the test set
403
408
404
409
test_losses = [test_model (i )
405
- for i in xrange (n_test_batches )]
410
+ for i in range (n_test_batches )]
406
411
test_score = numpy .mean (test_losses )
407
412
408
413
print (
@@ -419,8 +424,8 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
419
424
)
420
425
421
426
# save the best model
422
- with open ('best_model.pkl' , 'w ' ) as f :
423
- cPickle .dump (classifier , f )
427
+ with open ('best_model.pkl' , 'wb ' ) as f :
428
+ pickle .dump (classifier , f )
424
429
425
430
if patience <= iter :
426
431
done_looping = True
@@ -434,11 +439,11 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
434
439
)
435
440
% (best_validation_loss * 100. , test_score * 100. )
436
441
)
437
- print 'The code run for %d epochs, with %f epochs/sec' % (
438
- epoch , 1. * epoch / (end_time - start_time ))
439
- print >> sys . stderr , ('The code for file ' +
440
- os .path .split (__file__ )[1 ] +
441
- ' ran for %.1fs' % ((end_time - start_time )))
442
+ print ( 'The code run for %d epochs, with %f epochs/sec' % (
443
+ epoch , 1. * epoch / (end_time - start_time )))
444
+ print ( ('The code for file ' +
445
+ os .path .split (__file__ )[1 ] +
446
+ ' ran for %.1fs' % ((end_time - start_time ))), file = sys . stderr )
442
447
443
448
444
449
def predict ():
@@ -448,7 +453,7 @@ def predict():
448
453
"""
449
454
450
455
# load the saved model
451
- classifier = cPickle .load (open ('best_model.pkl' ))
456
+ classifier = pickle .load (open ('best_model.pkl' ))
452
457
453
458
# compile a predictor function
454
459
predict_model = theano .function (
@@ -462,8 +467,8 @@ def predict():
462
467
test_set_x = test_set_x .get_value ()
463
468
464
469
predicted_values = predict_model (test_set_x [:10 ])
465
- print ("Predicted values for the first 10 examples in test set:" )
466
- print predicted_values
470
+ print ("Predicted values for the first 10 examples in test set:" )
471
+ print ( predicted_values )
467
472
468
473
469
474
if __name__ == '__main__' :
0 commit comments