Skip to content

bpo-40791: Use CRYPTO_memcmp() for compare_digest #20456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Doc/library/hmac.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ This module also provides the following helper function:

.. versionadded:: 3.3

.. versionchanged:: 3.10

The function uses OpenSSL's ``CRYPTO_memcmp()`` internally when
available.


.. seealso::

Expand Down
3 changes: 2 additions & 1 deletion Lib/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
"""

import warnings as _warnings
from _operator import _compare_digest as compare_digest
try:
import _hashlib as _hashopenssl
except ImportError:
_hashopenssl = None
_openssl_md_meths = None
from _operator import _compare_digest as compare_digest
else:
_openssl_md_meths = frozenset(_hashopenssl.openssl_md_meth_names)
compare_digest = _hashopenssl.compare_digest
import hashlib as _hashlib

trans_5C = bytes((x ^ 0x5C) for x in range(256))
Expand Down
88 changes: 53 additions & 35 deletions Lib/test/test_hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

from test.support import hashlib_helper

from _operator import _compare_digest as operator_compare_digest

try:
from _hashlib import HMAC as C_HMAC
from _hashlib import hmac_new as c_hmac_new
from _hashlib import compare_digest as openssl_compare_digest
except ImportError:
C_HMAC = None
c_hmac_new = None
openssl_compare_digest = None


def ignore_warning(func):
Expand Down Expand Up @@ -505,110 +509,124 @@ def test_equality_new(self):

class CompareDigestTestCase(unittest.TestCase):

def test_compare_digest(self):
def test_hmac_compare_digest(self):
self._test_compare_digest(hmac.compare_digest)
if openssl_compare_digest is not None:
self.assertIs(hmac.compare_digest, openssl_compare_digest)
else:
self.assertIs(hmac.compare_digest, operator_compare_digest)

def test_operator_compare_digest(self):
self._test_compare_digest(operator_compare_digest)

@unittest.skipIf(openssl_compare_digest is None, "test requires _hashlib")
def test_openssl_compare_digest(self):
self._test_compare_digest(openssl_compare_digest)

def _test_compare_digest(self, compare_digest):
# Testing input type exception handling
a, b = 100, 200
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = 100, b"foobar"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", 200
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = "foobar", b"foobar"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", "foobar"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)

# Testing bytes of different lengths
a, b = b"foobar", b"foo"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing bytes of same lengths, different values
a, b = b"foobar", b"foobaz"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing bytes of same lengths, same values
a, b = b"foobar", b"foobar"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))

# Testing bytearrays of same lengths, same values
a, b = bytearray(b"foobar"), bytearray(b"foobar")
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))

# Testing bytearrays of different lengths
a, b = bytearray(b"foobar"), bytearray(b"foo")
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing bytearrays of same lengths, different values
a, b = bytearray(b"foobar"), bytearray(b"foobaz")
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing byte and bytearray of same lengths, same values
a, b = bytearray(b"foobar"), b"foobar"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(hmac.compare_digest(b, a))
self.assertTrue(compare_digest(a, b))
self.assertTrue(compare_digest(b, a))

# Testing byte bytearray of different lengths
a, b = bytearray(b"foobar"), b"foo"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(hmac.compare_digest(b, a))
self.assertFalse(compare_digest(a, b))
self.assertFalse(compare_digest(b, a))

# Testing byte and bytearray of same lengths, different values
a, b = bytearray(b"foobar"), b"foobaz"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(hmac.compare_digest(b, a))
self.assertFalse(compare_digest(a, b))
self.assertFalse(compare_digest(b, a))

# Testing str of same lengths
a, b = "foobar", "foobar"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))

# Testing str of different lengths
a, b = "foo", "foobar"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing bytes of same lengths, different values
a, b = "foobar", "foobaz"
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

# Testing error cases
a, b = "foobar", b"foobar"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", "foobar"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", 1
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = 100, 200
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)
a, b = "fooä", "fooä"
self.assertRaises(TypeError, hmac.compare_digest, a, b)
self.assertRaises(TypeError, compare_digest, a, b)

# subclasses are supported by ignore __eq__
class mystr(str):
def __eq__(self, other):
return False

a, b = mystr("foobar"), mystr("foobar")
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))
a, b = mystr("foobar"), "foobar"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))
a, b = mystr("foobar"), mystr("foobaz")
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))

class mybytes(bytes):
def __eq__(self, other):
return False

a, b = mybytes(b"foobar"), mybytes(b"foobar")
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))
a, b = mybytes(b"foobar"), b"foobar"
self.assertTrue(hmac.compare_digest(a, b))
self.assertTrue(compare_digest(a, b))
a, b = mybytes(b"foobar"), mybytes(b"foobaz")
self.assertFalse(hmac.compare_digest(a, b))
self.assertFalse(compare_digest(a, b))


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:func:`hashlib.compare_digest` uses OpenSSL's ``CRYPTO_memcmp()`` function
when OpenSSL is available.
116 changes: 116 additions & 0 deletions Modules/_hashopenssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
/* EVP is the preferred interface to hashing in OpenSSL */
#include <openssl/evp.h>
#include <openssl/hmac.h>
#include <openssl/crypto.h>
/* We use the object interface to discover what hashes OpenSSL supports. */
#include <openssl/objects.h>
#include "openssl/err.h"
Expand Down Expand Up @@ -1833,13 +1834,128 @@ _hashlib_get_fips_mode_impl(PyObject *module)
#endif // !LIBRESSL_VERSION_NUMBER


static int
_tscmp(const unsigned char *a, const unsigned char *b,
Py_ssize_t len_a, Py_ssize_t len_b)
{
/* loop count depends on length of b. Might leak very little timing
* information if sizes are different.
*/
Py_ssize_t length = len_b;
const void *left = a;
const void *right = b;
int result = 0;

if (len_a != length) {
left = b;
result = 1;
}

result |= CRYPTO_memcmp(left, right, length);

return (result == 0);
}

/* NOTE: Keep in sync with _operator.c implementation. */

/*[clinic input]
_hashlib.compare_digest

a: object
b: object
/

Return 'a == b'.

This function uses an approach designed to prevent
timing analysis, making it appropriate for cryptography.

a and b must both be of the same type: either str (ASCII only),
or any bytes-like object.

Note: If a and b are of different lengths, or if an error occurs,
a timing attack could theoretically reveal information about the
types and lengths of a and b--but not their values.
[clinic start generated code]*/

static PyObject *
_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b)
/*[clinic end generated code: output=6f1c13927480aed9 input=9c40c6e566ca12f5]*/
{
int rc;

/* ASCII unicode string */
if(PyUnicode_Check(a) && PyUnicode_Check(b)) {
if (PyUnicode_READY(a) == -1 || PyUnicode_READY(b) == -1) {
return NULL;
}
if (!PyUnicode_IS_ASCII(a) || !PyUnicode_IS_ASCII(b)) {
PyErr_SetString(PyExc_TypeError,
"comparing strings with non-ASCII characters is "
"not supported");
return NULL;
}

rc = _tscmp(PyUnicode_DATA(a),
PyUnicode_DATA(b),
PyUnicode_GET_LENGTH(a),
PyUnicode_GET_LENGTH(b));
}
/* fallback to buffer interface for bytes, bytesarray and other */
else {
Py_buffer view_a;
Py_buffer view_b;

if (PyObject_CheckBuffer(a) == 0 && PyObject_CheckBuffer(b) == 0) {
PyErr_Format(PyExc_TypeError,
"unsupported operand types(s) or combination of types: "
"'%.100s' and '%.100s'",
Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name);
return NULL;
}

if (PyObject_GetBuffer(a, &view_a, PyBUF_SIMPLE) == -1) {
return NULL;
}
if (view_a.ndim > 1) {
PyErr_SetString(PyExc_BufferError,
"Buffer must be single dimension");
PyBuffer_Release(&view_a);
return NULL;
}

if (PyObject_GetBuffer(b, &view_b, PyBUF_SIMPLE) == -1) {
PyBuffer_Release(&view_a);
return NULL;
}
if (view_b.ndim > 1) {
PyErr_SetString(PyExc_BufferError,
"Buffer must be single dimension");
PyBuffer_Release(&view_a);
PyBuffer_Release(&view_b);
return NULL;
}

rc = _tscmp((const unsigned char*)view_a.buf,
(const unsigned char*)view_b.buf,
view_a.len,
view_b.len);

PyBuffer_Release(&view_a);
PyBuffer_Release(&view_b);
}

return PyBool_FromLong(rc);
}

/* List of functions exported by this module */

static struct PyMethodDef EVP_functions[] = {
EVP_NEW_METHODDEF
PBKDF2_HMAC_METHODDEF
_HASHLIB_SCRYPT_METHODDEF
_HASHLIB_GET_FIPS_MODE_METHODDEF
_HASHLIB_COMPARE_DIGEST_METHODDEF
_HASHLIB_HMAC_SINGLESHOT_METHODDEF
_HASHLIB_HMAC_NEW_METHODDEF
_HASHLIB_OPENSSL_MD5_METHODDEF
Expand Down
2 changes: 2 additions & 0 deletions Modules/_operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,8 @@ _operator_length_hint_impl(PyObject *module, PyObject *obj,
return PyObject_LengthHint(obj, default_value);
}

/* NOTE: Keep in sync with _hashopenssl.c implementation. */

/*[clinic input]
_operator._compare_digest = _operator.eq

Expand Down
Loading