
Commit 75d31f7

first version
1 parent a5a91dc commit 75d31f7

File tree

1 file changed: +242 -0 lines changed


decision_tree/decision_tree.py

Lines changed: 242 additions & 0 deletions

#encoding=utf-8

import cv2
import time
import numpy as np
import pandas as pd

# cross_validation was renamed to model_selection in scikit-learn 0.18+
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

total_class = 10

# Binarization: threshold the grayscale image into a 0/1 image
def binaryzation(img):
    cv_img = img.astype(np.uint8)
    _, cv_img = cv2.threshold(cv_img, 50, 1, cv2.THRESH_BINARY_INV)
    return cv_img

def binaryzation_features(trainset):
    features = []

    for img in trainset:
        img = np.reshape(img, (28, 28))
        cv_img = img.astype(np.uint8)

        img_b = binaryzation(cv_img)
        features.append(img_b)

    features = np.array(features)
    features = np.reshape(features, (-1, 784))

    return features
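
# Each image is thresholded to 0/1 and flattened to 784 values, so every
# pixel becomes a binary feature for the decision tree below.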

class Tree(object):
    def __init__(self, node_type, Class=None, feature=None):
        self.node_type = node_type
        self.dict = {}          # feature value -> subtree
        self.Class = Class      # predicted class (meaningful for leaf nodes)
        self.feature = feature  # index of the feature this node splits on

    def add_tree(self, val, tree):
        self.dict[val] = tree

    def predict(self, features):
        if self.node_type == 'leaf':
            return self.Class

        # assumes the feature value was seen during training,
        # otherwise this lookup raises a KeyError
        tree = self.dict[features[self.feature]]
        return tree.predict(features)

def calc_ent(x):
    """
    calculate the Shannon entropy of x
    """

    x_value_list = set([x[i] for i in range(x.shape[0])])
    ent = 0.0
    for x_value in x_value_list:
        p = float(x[x == x_value].shape[0]) / x.shape[0]
        logp = np.log2(p)
        ent -= p * logp

    return ent

def calc_condition_ent(x, y):
    """
    calculate the conditional entropy H(y|x)
    """

    x_value_list = set([x[i] for i in range(x.shape[0])])
    ent = 0.0
    for x_value in x_value_list:
        sub_y = y[x == x_value]
        temp_ent = calc_ent(sub_y)
        ent += (float(sub_y.shape[0]) / y.shape[0]) * temp_ent

    return ent

def calc_ent_grap(x, y):
    """
    calculate the information gain g(y, x) = H(y) - H(y|x)
    """

    base_ent = calc_ent(y)
    condition_ent = calc_condition_ent(x, y)
    ent_grap = base_ent - condition_ent

    return ent_grap
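
# Quick illustrative check of the helpers above (toy values, assumed here):
# for x = np.array([0, 0, 1, 1]) and y = np.array([0, 1, 1, 1]),
#   calc_ent(y)              ~= 0.811
#   calc_condition_ent(x, y)  = 0.5 * 1.0 + 0.5 * 0.0 = 0.5
#   calc_ent_grap(x, y)      ~= 0.311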

def train(train_set, train_label, features, epsilon):
    global total_class

    LEAF = 'leaf'
    INTERNAL = 'internal'

    # Step 1: if every instance in train_set belongs to the same class Ck,
    # return a single-node tree labelled with that class
    label_dict = [0 for i in range(total_class)]
    for label in train_label:
        label_dict[label] += 1

    for label, label_count in enumerate(label_dict):
        if label_count == len(train_label):
            tree = Tree(LEAF, Class=label)
            return tree

    # Step 2: if the feature set is empty, return a leaf labelled with the
    # majority class
    max_len, max_class = 0, 0
    for i in range(total_class):
        class_i = [x for x in train_label if x == i]
        if len(class_i) > max_len:
            max_class = i
            max_len = len(class_i)

    if len(features) == 0:
        tree = Tree(LEAF, Class=max_class)
        return tree

    # Step 3: compute the information gain of each remaining feature and keep
    # the feature with the largest gain
    max_feature = 0
    max_gda = 0

    D = train_label
    HD = calc_ent(D)
    for feature in features:
        A = np.array(train_set[:, feature].flat)
        gda = HD - calc_condition_ent(A, D)

        if gda > max_gda:
            max_gda, max_feature = gda, feature

    # Step 4: if the largest gain is below the threshold epsilon, return a leaf
    # labelled with the majority class
    if max_gda < epsilon:
        tree = Tree(LEAF, Class=max_class)
        return tree

    # Step 5: otherwise split on the best feature and recursively build a
    # subtree for each of its non-empty value subsets
    sub_features = [x for x in features if x != max_feature]
    tree = Tree(INTERNAL, feature=max_feature)

    feature_col = np.array(train_set[:, max_feature].flat)
    feature_value_list = set([feature_col[i] for i in range(feature_col.shape[0])])
    for feature_value in feature_value_list:

        index = []
        for i in range(len(train_label)):
            if train_set[i][max_feature] == feature_value:
                index.append(i)

        sub_train_set = train_set[index]
        sub_train_label = train_label[index]

        sub_tree = train(sub_train_set, sub_train_label, sub_features, epsilon)
        tree.add_tree(feature_value, sub_tree)

    return tree
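
# Minimal illustrative usage of train() / Tree.predict() on assumed toy data
# (not taken from the data set used below):
#   toy_X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
#   toy_y = np.array([0, 0, 1, 1])
#   toy_tree = train(toy_X, toy_y, [0, 1], 0.1)   # splits on feature 0
#   toy_tree.predict(np.array([1, 0]))            # -> 1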

def predict(test_set, tree):

    result = []
    for features in test_set:
        tmp_predict = tree.predict(features)
        result.append(tmp_predict)
    return np.array(result)

if __name__ == '__main__':
    # classes = [0,0,1,1,0,0,0,1,1,1,1,1,1,1,0]
    #
    # age = [0,0,0,0,0,1,1,1,1,1,2,2,2,2,2]
    # occupation = [0,0,1,1,0,0,0,1,0,0,0,0,1,1,0]
    # house = [0,0,0,1,0,0,0,1,1,1,1,1,0,0,0]
    # loan = [0,1,1,0,0,0,1,1,2,2,2,1,1,2,0]
    #
    # features = []
    #
    # for i in range(15):
    #     feature = [age[i], occupation[i], house[i], loan[i]]
    #     features.append(feature)
    #
    # trainset = np.array(features)
    #
    # tree = train(trainset, np.array(classes), [0,1,2,3], 0.1)
    #
    # print(type(tree))
    # features = [0,0,0,1]
    # print(tree.predict(np.array(features)))

    print('Start reading data')

    time_1 = time.time()

    raw_data = pd.read_csv('../data/train.csv', header=0)
    data = raw_data.values
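
    # The CSV is assumed to have the digit label in column 0 and the 784 pixel
    # values in the remaining columns, which is the layout the slicing below relies on.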
    imgs = data[:, 1:]
    labels = data[:, 0]

    features = binaryzation_features(imgs)

    # Use 2/3 of the data for training and 1/3 for testing
    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=0.33, random_state=23323)
    # print(train_features.shape)
    # print(test_features.shape)

    time_2 = time.time()
    print('read data cost ', time_2 - time_1, ' second', '\n')

    print('Start training')
    tree = train(train_features, train_labels, [i for i in range(784)], 0.2)
    print(type(tree))
    time_3 = time.time()
    print('training cost ', time_3 - time_2, ' second', '\n')

    print('Start predicting')
    test_predict = predict(test_features, tree)
    time_4 = time.time()
    print('predicting cost ', time_4 - time_3, ' second', '\n')

    score = accuracy_score(test_labels, test_predict)
    print('The accuracy score is ', score)

0 commit comments
