Skip to content

Commit 14bd7c8

Browse files
committed
knn optimize top k select time
1 parent e4a14bb commit 14bd7c8

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

knn/knn.py

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -32,40 +32,55 @@ def get_hog_features(trainset):
3232
def Predict(testset,trainset,train_labels):
3333
predict = []
3434
count = 0
35+
3536
for test_vec in testset:
37+
# 输出当前运行的测试用例坐标,用于测试
3638
print count
3739
count += 1
3840

39-
knn_list = []
41+
knn_list = [] # 当前k个最近邻居
42+
max_index = -1 # 当前k个最近邻居中距离最远点的坐标
43+
max_dist = 0 # 当前k个最近邻居中距离最远点的距离
4044

41-
for i in range(len(train_labels)):
45+
# 先将前k个点放入k个最近邻居中,填充满knn_list
46+
for i in range(k):
4247
label = train_labels[i]
4348
train_vec = trainset[i]
4449

45-
dist = np.linalg.norm(train_vec - test_vec)
50+
dist = np.linalg.norm(train_vec - test_vec) # 计算两个点的欧氏距离
4651

47-
if len(knn_list) < k: # 如果还不够10个邻近点则直接添加即可
48-
knn_list.append((dist,label))
49-
else:
50-
max_index = -1
51-
max_dist = dist
52+
knn_list.append((dist,label))
53+
54+
# 剩下的点
55+
for i in range(k,len(train_labels)):
56+
label = train_labels[i]
57+
train_vec = trainset[i]
58+
59+
dist = np.linalg.norm(train_vec - test_vec) # 计算两个点的欧氏距离
5260

53-
# 寻找10个邻近点钟距离最远的点
61+
# 寻找10个邻近点钟距离最远的点
62+
if max_index < 0:
5463
for j in range(k):
5564
if max_dist < knn_list[j][0]:
5665
max_index = j
5766
max_dist = knn_list[max_index][0]
5867

59-
if max_index >= 0:
60-
knn_list[max_index] = (dist,label)
68+
# 如果当前k个最近邻居中存在点距离比当前点距离远,则替换
69+
if dist < max_dist:
70+
knn_list[max_index] = (dist,label)
71+
max_index = -1
72+
6173

74+
# 统计选票
6275
class_total = 10
6376
class_count = [0 for i in range(class_total)]
6477
for dist,label in knn_list:
6578
class_count[label] += 1
6679

80+
# 找出最大选票
6781
mmax = max(class_count)
6882

83+
# 找出最大选票标签
6984
for i in range(class_total):
7085
if mmax == class_count[i]:
7186
predict.append(i)

0 commit comments

Comments
 (0)