From 4ad58c1ac1d507139158603745c8acdfe3da6773 Mon Sep 17 00:00:00 2001 From: Moriyoshi Koizumi Date: Mon, 2 Nov 2020 12:43:11 +0900 Subject: [PATCH 1/2] Add connector-python compatible options. Also fixes #842. https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html --- pymysql/connections.py | 48 ++++++++-- pymysql/tests/test_connection.py | 148 ++++++++++++++++++++++++++++++- requirements-dev.txt | 1 + 3 files changed, 187 insertions(+), 10 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 75e07f34..b6bd9867 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -152,6 +152,12 @@ class Connection(object): (default: 10, min: 1, max: 31536000) :param ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. + :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate + :param ssl_cert: Path to the file that contains a PEM-formatted client certificate + :param ssl_disabled: A boolean value that disables usage of TLS + :param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate + :param ssl_verify_cert: Set to true to check the validity of server certificates + :param ssl_verify_identity: Set to true to check the server's identity :param read_default_group: Group to read from in the configuration file. :param compress: Not supported :param named_pipe: Not supported @@ -191,7 +197,9 @@ def __init__(self, host=None, user=None, password="", max_allowed_packet=16*1024*1024, defer_connect=False, auth_plugin_map=None, read_timeout=None, write_timeout=None, bind_address=None, binary_prefix=False, program_name=None, - server_public_key=None): + server_public_key=None, ssl_ca=None, ssl_cert=None, + ssl_disabled=None, ssl_key=None, ssl_verify_cert=None, + ssl_verify_identity=None): if use_unicode is None and sys.version_info[0] > 2: use_unicode = True @@ -245,12 +253,21 @@ def _config(key, arg): ssl[key] = value self.ssl = False - if ssl: - if not SSL_ENABLED: - raise NotImplementedError("ssl module not found") - self.ssl = True - client_flag |= CLIENT.SSL - self.ctx = self._create_ssl_ctx(ssl) + if not ssl_disabled: + if ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: + ssl = { + "ca": ssl_ca, + "cert": ssl_cert, + "key": ssl_key, + "check_hostname": bool(ssl_verify_identity), + "verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False, + } + if ssl: + if not SSL_ENABLED: + raise NotImplementedError("ssl module not found") + self.ssl = True + client_flag |= CLIENT.SSL + self.ctx = self._create_ssl_ctx(ssl) self.host = host or "localhost" self.port = port or 3306 @@ -334,7 +351,22 @@ def _create_ssl_ctx(self, sslp): hasnoca = ca is None and capath is None ctx = ssl.create_default_context(cafile=ca, capath=capath) ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True) - ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + verify_mode_value = sslp.get('verify_mode') + if verify_mode_value is None: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + elif isinstance(verify_mode_value, bool): + ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE + else: + if isinstance(verify_mode_value, (text_type, str_type)): + verify_mode_value = verify_mode_value.lower() + if verify_mode_value in ("none", "0", "false", "no"): + ctx.verify_mode = ssl.CERT_NONE + elif verify_mode_value == "optional": + ctx.verify_mode = ssl.CERT_OPTIONAL + elif verify_mode_value in ("required", "1", "true", "yes"): + ctx.verify_mode = ssl.CERT_REQUIRED + else: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED if 'cert' in sslp: ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key')) if 'cipher' in sslp: diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index e4d24c44..a94fbd61 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -1,14 +1,14 @@ import datetime +import ssl import sys import time +import mock import pytest import pymysql from pymysql.tests import base from pymysql._compat import text_type from pymysql.constants import CLIENT -import pytest - class TempUser: def __init__(self, c, user, db, auth=None, authdata=None, password=None): @@ -478,6 +478,150 @@ def test_defer_connect(self): c.close() sock.close() + def test_ssl_connect(self): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + "cipher": "cipher", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_called_with("cipher") + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (True, "1", "yes", "true"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (None, False, "0", "no", "false"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_ca in ("ca", None): + for ssl_verify_cert in ("foo", "bar", ""): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca=ssl_ca, + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == (ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE), (ssl_ca, ssl_verify_cert) + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + create_default_context.reset() + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ssl_verify_identity=True, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert not create_default_context.called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert not create_default_context.called + # A custom type and function to escape it class Foo(object): diff --git a/requirements-dev.txt b/requirements-dev.txt index d65512fb..69d3f68a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ cryptography PyNaCl>=1.4.0 pytest +mock From 1c4d49c3ab3c0dfac5ae5f8b1b21e33dad200dee Mon Sep 17 00:00:00 2001 From: Moriyoshi Koizumi Date: Mon, 2 Nov 2020 13:32:55 +0900 Subject: [PATCH 2/2] Properly handle the case where only ssl_ca is given. --- pymysql/connections.py | 8 +++++--- pymysql/tests/test_connection.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index b6bd9867..b4dc7ad7 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -254,14 +254,16 @@ def _config(key, arg): self.ssl = False if not ssl_disabled: - if ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: + if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: ssl = { "ca": ssl_ca, - "cert": ssl_cert, - "key": ssl_key, "check_hostname": bool(ssl_verify_identity), "verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False, } + if ssl_cert is not None: + ssl["cert"] = ssl_cert + if ssl_key is not None: + ssl["key" ] = ssl_key if ssl: if not SSL_ENABLED: raise NotImplementedError("ssl module not found") diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index a94fbd61..966b2696 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -514,6 +514,19 @@ def test_ssl_connect(self): dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") dummy_ssl_context.set_ciphers.assert_not_called + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect") as connect, \ + mock.patch("pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_not_called + dummy_ssl_context.set_ciphers.assert_not_called + dummy_ssl_context = mock.Mock(options=0) with mock.patch("pymysql.connections.Connection.connect") as connect, \ mock.patch("pymysql.connections.ssl.create_default_context", @@ -583,7 +596,6 @@ def test_ssl_connect(self): with mock.patch("pymysql.connections.Connection.connect") as connect, \ mock.patch("pymysql.connections.ssl.create_default_context", new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: - create_default_context.reset() pymysql.connect( ssl_ca="ca", ssl_cert="cert",