Skip to content

Commit 593c301

Browse files
committed
SOFTMAX 在 MNIST 数据集上的使用
Former-commit-id: abec548 Former-commit-id: a814943cfd15ec94b0f44ec6c01765215158f706 [formerly accd09e802e62429fc3a48085723cfdc60856faa] [formerly 5d3d056126ff935b2d7ddf863b7655170c5bc484 [formerly 52ffbd62e93672aaad8708431d756a125c371b5d]] [formerly 8ddf617c6d0eb3f9f4dd7f8e41c1d4bd19db70ab [formerly ed1afafca9300c52821eee5bc22c13842a8aba6f] [formerly 7110d4983637226002bae23b63daf7a9df322d45 [formerly a614ed1d6d59f1ff0ae1d0411d1c2dbb6830d227]]] [formerly bd0a127baa857df084995efc675ef4b7cf5dfb8e [formerly 1e276100df4ef79c355cb1f166131413ce50ce17] [formerly df311a5feaef6197c75e75c8a44a70c1ad1ec9dd [formerly b929dff4beee9f92cd1a9fb9e5d22074014c5ea6]] [formerly ce2ab28c89812a26e9e87a65d5a8b18bfea125fb [formerly 87e830468c8e3ca937fd2278d1afae7fe8815403] [formerly 757b7a8a9f8b70ab5a98cb2a28cbe6ce93b7ffb6 [formerly 0cc53af]]]] Former-commit-id: b011c4d634ca982747047e1ba70e4d3befff0b44 [formerly c365e94aa41ea859670c3804524c3ffcb2a1ce26] [formerly 5bc9b02b13b1ca3ab336cc827fb557e72a27dc27 [formerly ed5e172b762eb405d4d385bdd7ae93310a90d3fd]] [formerly f4e3cbca10e4dc4e1bcff898593d3944f350b194 [formerly f19b97ef290825faa6e08a9078051a0870bfb6a8] [formerly 0002d432c0a374ca8b3905e264d18485d3d82ec0 [formerly c82e9852a1b663d9135280d81e8e1f9c2c63c128]]] Former-commit-id: 0c565b5fd94a7b9ba60e740a1f421e936b1dcca1 [formerly 5b9436f84016b6026f8c29761ad5f3a93681215a] [formerly 8a69c01db08f17178fa60fb2ccb374bc24ba8113 [formerly 537081d21c19be4bfdddd837407caf43430135eb]] Former-commit-id: c9c005790d94654fd00a772df83d582f313620d0 [formerly c755b7f9a64684c4a0830ff195004db25c8dd08d] Former-commit-id: 209a7d653ade0c27eebcc6c47fb6dc6da329af85
1 parent df5f422 commit 593c301

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
/.ipynb_checkpoints/*
33
*-checkpoint.ipynb
44
*.pyc
5+
/*/mnist/*
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
e742f36bb55ffc69bcd84b3a7e8cb4d486a918a4

09. theano/download_mnist.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
import os.path
3+
import urllib
4+
import gzip
5+
import shutil
6+
7+
if not os.path.exists('mnist'):
8+
os.mkdir('mnist')
9+
10+
def download_and_gzip(name):
11+
if not os.path.exists(name + '.gz'):
12+
urllib.urlretrieve('http://yann.lecun.com/exdb/' + name + '.gz', name + '.gz')
13+
if not os.path.exists(name):
14+
with gzip.open(name + '.gz', 'rb') as f_in, open(name, 'wb') as f_out:
15+
shutil.copyfileobj(f_in, f_out)
16+
17+
download_and_gzip('mnist/train-images-idx3-ubyte')
18+
download_and_gzip('mnist/train-labels-idx1-ubyte')
19+
download_and_gzip('mnist/t10k-images-idx3-ubyte')
20+
download_and_gzip('mnist/t10k-labels-idx1-ubyte')

09. theano/load.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import os
3+
4+
datasets_dir = './'
5+
6+
def one_hot(x,n):
7+
if type(x) == list:
8+
x = np.array(x)
9+
x = x.flatten()
10+
o_h = np.zeros((len(x),n))
11+
o_h[np.arange(len(x)),x] = 1
12+
return o_h
13+
14+
def mnist(ntrain=60000,ntest=10000,onehot=True):
15+
data_dir = os.path.join(datasets_dir,'mnist/')
16+
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
17+
loaded = np.fromfile(file=fd,dtype=np.uint8)
18+
trX = loaded[16:].reshape((60000,28*28)).astype(float)
19+
20+
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
21+
loaded = np.fromfile(file=fd,dtype=np.uint8)
22+
trY = loaded[8:].reshape((60000))
23+
24+
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
25+
loaded = np.fromfile(file=fd,dtype=np.uint8)
26+
teX = loaded[16:].reshape((10000,28*28)).astype(float)
27+
28+
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
29+
loaded = np.fromfile(file=fd,dtype=np.uint8)
30+
teY = loaded[8:].reshape((10000))
31+
32+
trX = trX/255.
33+
teX = teX/255.
34+
35+
trX = trX[:ntrain]
36+
trY = trY[:ntrain]
37+
38+
teX = teX[:ntest]
39+
teY = teY[:ntest]
40+
41+
if onehot:
42+
trY = one_hot(trY, 10)
43+
teY = one_hot(teY, 10)
44+
else:
45+
trY = np.asarray(trY)
46+
teY = np.asarray(teY)
47+
48+
return trX,teX,trY,teY

0 commit comments

Comments
 (0)