diff --git a/pymysql/connections.py b/pymysql/connections.py index 75e07f34..b4dc7ad7 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -152,6 +152,12 @@ class Connection(object): (default: 10, min: 1, max: 31536000) :param ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. + :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate + :param ssl_cert: Path to the file that contains a PEM-formatted client certificate + :param ssl_disabled: A boolean value that disables usage of TLS + :param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate + :param ssl_verify_cert: Set to true to check the validity of server certificates + :param ssl_verify_identity: Set to true to check the server's identity :param read_default_group: Group to read from in the configuration file. :param compress: Not supported :param named_pipe: Not supported @@ -191,7 +197,9 @@ def __init__(self, host=None, user=None, password="", max_allowed_packet=16*1024*1024, defer_connect=False, auth_plugin_map=None, read_timeout=None, write_timeout=None, bind_address=None, binary_prefix=False, program_name=None, - server_public_key=None): + server_public_key=None, ssl_ca=None, ssl_cert=None, + ssl_disabled=None, ssl_key=None, ssl_verify_cert=None, + ssl_verify_identity=None): if use_unicode is None and sys.version_info[0] > 2: use_unicode = True @@ -245,12 +253,23 @@ def _config(key, arg): ssl[key] = value self.ssl = False - if ssl: - if not SSL_ENABLED: - raise NotImplementedError("ssl module not found") - self.ssl = True - client_flag |= CLIENT.SSL - self.ctx = self._create_ssl_ctx(ssl) + if not ssl_disabled: + if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: + ssl = { + "ca": ssl_ca, + "check_hostname": bool(ssl_verify_identity), + "verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False, + } + if ssl_cert is not None: + ssl["cert"] = ssl_cert + if ssl_key is not None: + ssl["key" ] = ssl_key + if ssl: + if not SSL_ENABLED: + raise NotImplementedError("ssl module not found") + self.ssl = True + client_flag |= CLIENT.SSL + self.ctx = self._create_ssl_ctx(ssl) self.host = host or "localhost" self.port = port or 3306 @@ -334,7 +353,22 @@ def _create_ssl_ctx(self, sslp): hasnoca = ca is None and capath is None ctx = ssl.create_default_context(cafile=ca, capath=capath) ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True) - ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + verify_mode_value = sslp.get('verify_mode') + if verify_mode_value is None: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + elif isinstance(verify_mode_value, bool): + ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE + else: + if isinstance(verify_mode_value, (text_type, str_type)): + verify_mode_value = verify_mode_value.lower() + if verify_mode_value in ("none", "0", "false", "no"): + ctx.verify_mode = ssl.CERT_NONE + elif verify_mode_value == "optional": + ctx.verify_mode = ssl.CERT_OPTIONAL + elif verify_mode_value in ("required", "1", "true", "yes"): + ctx.verify_mode = ssl.CERT_REQUIRED + else: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED if 'cert' in sslp: ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key')) if 'cipher' in sslp: diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index e4d24c44..966b2696 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -1,14 +1,14 @@ import datetime +import ssl import sys import time +import mock import pytest import pymysql from pymysql.tests import base from pymysql._compat import text_type from pymysql.constants import CLIENT -import pytest - class TempUser: def __init__(self, c, user, db, auth=None, authdata=None, password=None): @@ -478,6 +478,162 @@ def test_defer_connect(self): c.close() sock.close() + def test_ssl_connect(self): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + "cipher": "cipher", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_called_with("cipher") + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_not_called + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (True, "1", "yes", "true"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (None, False, "0", "no", "false"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_ca in ("ca", None): + for ssl_verify_cert in ("foo", "bar", ""): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca=ssl_ca, + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == (ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE), (ssl_ca, ssl_verify_cert) + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ssl_verify_identity=True, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert not create_default_context.called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert not create_default_context.called + # A custom type and function to escape it class Foo(object): diff --git a/requirements-dev.txt b/requirements-dev.txt index d65512fb..69d3f68a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ cryptography PyNaCl>=1.4.0 pytest +mock