@@ -32,40 +32,55 @@ def get_hog_features(trainset):
32
32
def Predict (testset ,trainset ,train_labels ):
33
33
predict = []
34
34
count = 0
35
+
35
36
for test_vec in testset :
37
+ # 输出当前运行的测试用例坐标,用于测试
36
38
print count
37
39
count += 1
38
40
39
- knn_list = []
41
+ knn_list = [] # 当前k个最近邻居
42
+ max_index = - 1 # 当前k个最近邻居中距离最远点的坐标
43
+ max_dist = 0 # 当前k个最近邻居中距离最远点的距离
40
44
41
- for i in range (len (train_labels )):
45
+ # 先将前k个点放入k个最近邻居中,填充满knn_list
46
+ for i in range (k ):
42
47
label = train_labels [i ]
43
48
train_vec = trainset [i ]
44
49
45
- dist = np .linalg .norm (train_vec - test_vec )
50
+ dist = np .linalg .norm (train_vec - test_vec ) # 计算两个点的欧氏距离
46
51
47
- if len (knn_list ) < k : # 如果还不够10个邻近点则直接添加即可
48
- knn_list .append ((dist ,label ))
49
- else :
50
- max_index = - 1
51
- max_dist = dist
52
+ knn_list .append ((dist ,label ))
53
+
54
+ # 剩下的点
55
+ for i in range (k ,len (train_labels )):
56
+ label = train_labels [i ]
57
+ train_vec = trainset [i ]
58
+
59
+ dist = np .linalg .norm (train_vec - test_vec ) # 计算两个点的欧氏距离
52
60
53
- # 寻找10个邻近点钟距离最远的点
61
+ # 寻找10个邻近点钟距离最远的点
62
+ if max_index < 0 :
54
63
for j in range (k ):
55
64
if max_dist < knn_list [j ][0 ]:
56
65
max_index = j
57
66
max_dist = knn_list [max_index ][0 ]
58
67
59
- if max_index >= 0 :
60
- knn_list [max_index ] = (dist ,label )
68
+ # 如果当前k个最近邻居中存在点距离比当前点距离远,则替换
69
+ if dist < max_dist :
70
+ knn_list [max_index ] = (dist ,label )
71
+ max_index = - 1
72
+
61
73
74
+ # 统计选票
62
75
class_total = 10
63
76
class_count = [0 for i in range (class_total )]
64
77
for dist ,label in knn_list :
65
78
class_count [label ] += 1
66
79
80
+ # 找出最大选票
67
81
mmax = max (class_count )
68
82
83
+ # 找出最大选票标签
69
84
for i in range (class_total ):
70
85
if mmax == class_count [i ]:
71
86
predict .append (i )
0 commit comments