Skip to content

Commit 8f8c8b1

Browse files
committed
Expose face_distance and prevent errors with empty inputs
1 parent b7b8b9f commit 8f8c8b1

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

face_recognition/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
__email__ = 'ageitgey@gmail.com'
55
__version__ = '0.1.0'
66

7-
from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces
7+
from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces, face_distance

face_recognition/api.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,19 @@ def _trim_css_to_bounds(css, image_shape):
5252
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)
5353

5454

55-
def _face_distance(faces, face_to_compare):
55+
def face_distance(face_encodings, face_to_compare):
5656
"""
57-
Given a list of face encodings, compared them to a known face encoding and get a euclidean distance
58-
for each comparison face.
57+
Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
58+
for each comparison face. The distance tells you how similar the faces are.
5959
6060
:param faces: List of face encodings to compare
6161
:param face_to_compare: A face encoding to compare against
62-
:return: A list with the distance for each face in the same order as the 'faces' array
62+
:return: A numpy ndarray with the distance for each face in the same order as the 'faces' array
6363
"""
64-
return np.linalg.norm(faces - face_to_compare, axis=1)
64+
if len(face_encodings) == 0:
65+
return np.empty((0))
66+
67+
return np.linalg.norm(face_encodings - face_to_compare, axis=1)
6568

6669

6770
def load_image_file(filename, mode='RGB'):
@@ -154,4 +157,4 @@ def compare_faces(known_face_encodings, face_encoding_to_check, tolerance=0.6):
154157
:param tolerance: How much distance between faces to consider it a match. Lower is more strict. 0.6 is typical best performance.
155158
:return: A list of True/False values indicating which known_face_encodings match the face encoding to check
156159
"""
157-
return list(_face_distance(known_face_encodings, face_encoding_to_check) <= tolerance)
160+
return list(face_distance(known_face_encodings, face_encoding_to_check) <= tolerance)

tests/test_face_recognition.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import numpy as np
1415
from click.testing import CliRunner
1516

1617
from face_recognition import api
@@ -94,6 +95,50 @@ def test_face_encodings(self):
9495
self.assertEqual(len(encodings), 1)
9596
self.assertEqual(len(encodings[0]), 128)
9697

98+
def test_face_distance(self):
99+
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
100+
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
101+
img_a3 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama3.jpg'))
102+
103+
img_b1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
104+
105+
face_encoding_a1 = api.face_encodings(img_a1)[0]
106+
face_encoding_a2 = api.face_encodings(img_a2)[0]
107+
face_encoding_a3 = api.face_encodings(img_a3)[0]
108+
face_encoding_b1 = api.face_encodings(img_b1)[0]
109+
110+
faces_to_compare = [
111+
face_encoding_a2,
112+
face_encoding_a3,
113+
face_encoding_b1]
114+
115+
distance_results = api.face_distance(faces_to_compare, face_encoding_a1)
116+
117+
# 0.6 is the default face distance match threshold. So we'll spot-check that the numbers returned
118+
# are above or below that based on if they should match (since the exact numbers could vary).
119+
self.assertEqual(type(distance_results), np.ndarray)
120+
self.assertLessEqual(distance_results[0], 0.6)
121+
self.assertLessEqual(distance_results[1], 0.6)
122+
self.assertGreater(distance_results[2], 0.6)
123+
124+
def test_face_distance_empty_lists(self):
125+
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
126+
face_encoding = api.face_encodings(img)[0]
127+
128+
# empty python list
129+
faces_to_compare = []
130+
131+
distance_results = api.face_distance(faces_to_compare, face_encoding)
132+
self.assertEqual(type(distance_results), np.ndarray)
133+
self.assertEqual(len(distance_results), 0)
134+
135+
# empty numpy list
136+
faces_to_compare = np.array([])
137+
138+
distance_results = api.face_distance(faces_to_compare, face_encoding)
139+
self.assertEqual(type(distance_results), np.ndarray)
140+
self.assertEqual(len(distance_results), 0)
141+
97142
def test_compare_faces(self):
98143
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
99144
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
@@ -112,10 +157,30 @@ def test_compare_faces(self):
112157
face_encoding_b1]
113158

114159
match_results = api.compare_faces(faces_to_compare, face_encoding_a1)
160+
161+
self.assertEqual(type(match_results), list)
115162
self.assertTrue(match_results[0])
116163
self.assertTrue(match_results[1])
117164
self.assertFalse(match_results[2])
118165

166+
def test_compare_faces_empty_lists(self):
167+
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
168+
face_encoding = api.face_encodings(img)[0]
169+
170+
# empty python list
171+
faces_to_compare = []
172+
173+
match_results = api.compare_faces(faces_to_compare, face_encoding)
174+
self.assertEqual(type(match_results), list)
175+
self.assertListEqual(match_results, [])
176+
177+
# empty numpy list
178+
faces_to_compare = np.array([])
179+
180+
match_results = api.compare_faces(faces_to_compare, face_encoding)
181+
self.assertEqual(type(match_results), list)
182+
self.assertListEqual(match_results, [])
183+
119184
def test_command_line_interface(self):
120185
target_string = '--help Show this message and exit.'
121186
runner = CliRunner()

0 commit comments

Comments
 (0)