Skip to content

Commit 20d50f3

Browse files
committed
Don't choke on images that use 32 bits per channel. Downsample them.
1 parent 961e2c6 commit 20d50f3

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

face_recognition/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ def _face_distance(faces, face_to_compare):
4444
return np.array([np.linalg.norm(face - face_to_compare) for face in faces])
4545

4646

47-
def load_image_file(filename):
47+
def load_image_file(filename, mode='RGB'):
4848
"""
4949
Loads an image file (.jpg, .png, etc) into a numpy array
5050
5151
:param filename: image file to load
52+
:param mode: format to convert the image to. Only 'RGB' (8-bit RGB, 3 channels) and 'L' (black and white) are supported.
5253
:return: image contents as numpy array
5354
"""
54-
return scipy.misc.imread(filename)
55+
return scipy.misc.imread(filename, mode=mode)
5556

5657

5758
def _raw_face_locations(img, number_of_times_to_upsample=1):

tests/test_face_recognition.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ class Test_face_recognition(unittest.TestCase):
2121

2222
def test_load_image_file(self):
2323
img = api.load_image_file(os.path.join(os.path.dirname(__file__), "test_images", "obama.jpg"))
24-
2524
assert img.shape == (1137, 910, 3)
2625

26+
def test_load_image_file_32bit(self):
27+
img = api.load_image_file(os.path.join(os.path.dirname(__file__), "test_images", "32bit.png"))
28+
assert img.shape == (1200, 626, 3)
29+
2730
def test_raw_face_locations(self):
2831
img = api.load_image_file(os.path.join(os.path.dirname(__file__), "test_images", "obama.jpg"))
2932
detected_faces = api._raw_face_locations(img)
@@ -32,6 +35,14 @@ def test_raw_face_locations(self):
3235
assert detected_faces[0].top() == 142
3336
assert detected_faces[0].bottom() == 409
3437

38+
def test_raw_face_locations_32bit_image(self):
39+
img = api.load_image_file(os.path.join(os.path.dirname(__file__), "test_images", "32bit.png"))
40+
detected_faces = api._raw_face_locations(img)
41+
42+
assert len(detected_faces) == 1
43+
assert detected_faces[0].top() == 290
44+
assert detected_faces[0].bottom() == 558
45+
3546
def test_face_locations(self):
3647
img = api.load_image_file(os.path.join(os.path.dirname(__file__), "test_images", "obama.jpg"))
3748
detected_faces = api.face_locations(img)

tests/test_images/32bit.png

867 KB
Loading

0 commit comments

Comments
 (0)