Skip to content

Commit 60c5c46

Browse files
committed
Added mnist.py
1 parent 3129d5a commit 60c5c46

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed

mnist.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
########################################################################
2+
#
3+
# Downloads the MNIST data-set for recognizing hand-written digits.
4+
#
5+
# Implemented in Python 3.6
6+
#
7+
# Usage:
8+
# 1) Create a new object instance: data = MNIST(data_dir="data/MNIST/")
9+
# This automatically downloads the files to the given dir.
10+
# 2) Use the training-set as data.x_train, data.y_train and data.y_train_cls
11+
# 3) Get random batches of training data using data.random_batch()
12+
# 4) Use the test-set as data.x_test, data.y_test and data.y_test_cls
13+
#
14+
########################################################################
15+
#
16+
# This file is part of the TensorFlow Tutorials available at:
17+
#
18+
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
19+
#
20+
# Published under the MIT License. See the file LICENSE for details.
21+
#
22+
# Copyright 2016-18 by Magnus Erik Hvass Pedersen
23+
#
24+
########################################################################
25+
26+
import numpy as np
27+
import gzip
28+
import os
29+
from dataset import one_hot_encoded
30+
from download import download
31+
32+
########################################################################
33+
34+
# Base URL for downloading the data-files from the internet.
# Each filename below is appended to this URL by the download helper.
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

# Filenames for the data-set (gzip-compressed).
# "x" files hold the images and "y" files hold the class-labels.
filename_x_train = "train-images-idx3-ubyte.gz"
filename_y_train = "train-labels-idx1-ubyte.gz"
filename_x_test = "t10k-images-idx3-ubyte.gz"
filename_y_test = "t10k-labels-idx1-ubyte.gz"
42+
43+
########################################################################
44+
45+
46+
class MNIST:
    """
    The MNIST data-set for recognizing hand-written digits.
    This automatically downloads the data-files if they do
    not already exist in the local data_dir.

    The training-file is split into a training-set (the first
    `num_train` images) and a validation-set (the remaining images).

    Note: Pixel-values are floats between 0.0 and 1.0.
    """

    # The images are 28 pixels in each dimension.
    img_size = 28

    # The images are stored in one-dimensional arrays of this length.
    img_size_flat = img_size * img_size

    # Tuple with height and width of images used to reshape arrays.
    img_shape = (img_size, img_size)

    # Number of colour channels for the images: 1 channel for gray-scale.
    num_channels = 1

    # Tuple with height, width and depth used to reshape arrays.
    # This is used for reshaping in Keras.
    img_shape_full = (img_size, img_size, num_channels)

    # Number of classes, one class for each of 10 digits.
    num_classes = 10

    def __init__(self, data_dir="data/MNIST/"):
        """
        Load the MNIST data-set. Automatically downloads the files
        if they do not already exist locally.

        After construction the following attributes are available:
        x_train, x_val, x_test (flattened images as floats),
        y_train_cls, y_val_cls, y_test_cls (integer class-numbers) and
        y_train, y_val, y_test (one-hot encoded class-numbers).

        :param data_dir: Base-directory for downloading files.
        """

        # Copy args to self.
        self.data_dir = data_dir

        # Number of images in each sub-set.
        self.num_train = 55000
        self.num_val = 5000
        self.num_test = 10000

        # Download / load the training-set.
        x_train = self._load_images(filename=filename_x_train)
        y_train_cls = self._load_cls(filename=filename_y_train)

        # Split the training-set into train / validation.
        # Pixel-values are converted from ints between 0 and 255
        # to floats between 0.0 and 1.0.
        self.x_train = x_train[0:self.num_train] / 255.0
        self.x_val = x_train[self.num_train:] / 255.0
        self.y_train_cls = y_train_cls[0:self.num_train]
        self.y_val_cls = y_train_cls[self.num_train:]

        # Download / load the test-set.
        self.x_test = self._load_images(filename=filename_x_test) / 255.0
        self.y_test_cls = self._load_cls(filename=filename_y_test)

        # Convert the class-numbers from bytes to ints as that is needed
        # some places in TensorFlow.
        # The builtin `int` is used because the `np.int` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24; it was
        # only ever an alias for the builtin, so the dtype is unchanged.
        self.y_train_cls = self.y_train_cls.astype(int)
        self.y_val_cls = self.y_val_cls.astype(int)
        self.y_test_cls = self.y_test_cls.astype(int)

        # Convert the integer class-numbers into one-hot encoded arrays.
        self.y_train = one_hot_encoded(class_numbers=self.y_train_cls,
                                       num_classes=self.num_classes)
        self.y_val = one_hot_encoded(class_numbers=self.y_val_cls,
                                     num_classes=self.num_classes)
        self.y_test = one_hot_encoded(class_numbers=self.y_test_cls,
                                      num_classes=self.num_classes)

    def _load_data(self, filename, offset):
        """
        Load the data in the given file. Automatically downloads the file
        if it does not already exist in the data_dir.

        :param filename: Name of the data-file.
        :param offset: Start offset in bytes when reading the data-file.
                       The bytes before this offset are skipped.
        :return: The data as a one-dimensional numpy array of uint8.
        """

        # Download the file from the internet if it does not exist locally.
        # NOTE(review): skipping existing files is the job of the project's
        # download() helper, which is not visible here - confirm.
        download(base_url=base_url, filename=filename, download_dir=self.data_dir)

        # Read the whole decompressed data-file and interpret the bytes
        # after the offset as unsigned 8-bit integers.
        path = os.path.join(self.data_dir, filename)
        with gzip.open(path, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=offset)

        return data

    def _load_images(self, filename):
        """
        Load image-data from the given file.
        Automatically downloads the file if it does not exist locally.

        :param filename: Name of the data-file.
        :return: Numpy array with shape (num_images, img_size_flat).
        """

        # Read the data as one long array of bytes.
        # The first 16 bytes are skipped (presumably the file-header).
        data = self._load_data(filename=filename, offset=16)

        # Reshape to 2-dim array with shape (num_images, img_size_flat).
        images_flat = data.reshape(-1, self.img_size_flat)

        return images_flat

    def _load_cls(self, filename):
        """
        Load class-numbers from the given file.
        Automatically downloads the file if it does not exist locally.

        :param filename: Name of the data-file.
        :return: One-dimensional numpy array with one class-number per image.
        """

        # The first 8 bytes are skipped (presumably the file-header).
        return self._load_data(filename=filename, offset=8)

    def random_batch(self, batch_size=32):
        """
        Create a random batch of training-data.

        The images are sampled with replacement, so the same image
        may occur more than once in a batch.

        :param batch_size: Number of images in the batch.
        :return: 3 numpy arrays (x, y, y_cls)
        """

        # Create a random index into the training-set.
        idx = np.random.randint(low=0, high=self.num_train, size=batch_size)

        # Use the index to lookup random training-data.
        # The same index is used for all three arrays so the rows
        # of x, y and y_cls refer to the same images.
        x_batch = self.x_train[idx]
        y_batch = self.y_train[idx]
        y_batch_cls = self.y_train_cls[idx]

        return x_batch, y_batch, y_batch_cls
184+
185+
186+
########################################################################

0 commit comments

Comments
 (0)