Skip to content

Add MySQL Connector/Python compatible SSL options. #903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ class Connection(object):
(default: 10, min: 1, max: 31536000)
:param ssl:
A dict of arguments similar to mysql_ssl_set()'s parameters.
:param ssl_ca: Path to the file that contains a PEM-formatted CA certificate
:param ssl_cert: Path to the file that contains a PEM-formatted client certificate
:param ssl_disabled: A boolean value that disables usage of TLS
:param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate
:param ssl_verify_cert: Set to true to check the validity of server certificates
:param ssl_verify_identity: Set to true to check the server's identity
:param read_default_group: Group to read from in the configuration file.
:param compress: Not supported
:param named_pipe: Not supported
Expand Down Expand Up @@ -191,7 +197,9 @@ def __init__(self, host=None, user=None, password="",
max_allowed_packet=16*1024*1024, defer_connect=False,
auth_plugin_map=None, read_timeout=None, write_timeout=None,
bind_address=None, binary_prefix=False, program_name=None,
server_public_key=None):
server_public_key=None, ssl_ca=None, ssl_cert=None,
ssl_disabled=None, ssl_key=None, ssl_verify_cert=None,
ssl_verify_identity=None):
if use_unicode is None and sys.version_info[0] > 2:
use_unicode = True

Expand Down Expand Up @@ -245,12 +253,23 @@ def _config(key, arg):
ssl[key] = value

self.ssl = False
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
self.ssl = True
client_flag |= CLIENT.SSL
self.ctx = self._create_ssl_ctx(ssl)
if not ssl_disabled:
if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
ssl = {
"ca": ssl_ca,
"check_hostname": bool(ssl_verify_identity),
"verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False,
}
if ssl_cert is not None:
ssl["cert"] = ssl_cert
if ssl_key is not None:
ssl["key" ] = ssl_key
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
self.ssl = True
client_flag |= CLIENT.SSL
self.ctx = self._create_ssl_ctx(ssl)

self.host = host or "localhost"
self.port = port or 3306
Expand Down Expand Up @@ -334,7 +353,22 @@ def _create_ssl_ctx(self, sslp):
hasnoca = ca is None and capath is None
ctx = ssl.create_default_context(cafile=ca, capath=capath)
ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
verify_mode_value = sslp.get('verify_mode')
if verify_mode_value is None:
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
elif isinstance(verify_mode_value, bool):
ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
else:
if isinstance(verify_mode_value, (text_type, str_type)):
verify_mode_value = verify_mode_value.lower()
if verify_mode_value in ("none", "0", "false", "no"):
ctx.verify_mode = ssl.CERT_NONE
elif verify_mode_value == "optional":
ctx.verify_mode = ssl.CERT_OPTIONAL
elif verify_mode_value in ("required", "1", "true", "yes"):
ctx.verify_mode = ssl.CERT_REQUIRED
else:
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
if 'cert' in sslp:
ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
if 'cipher' in sslp:
Expand Down
160 changes: 158 additions & 2 deletions pymysql/tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
import ssl
import sys
import time
import mock
import pytest
import pymysql
from pymysql.tests import base
from pymysql._compat import text_type
from pymysql.constants import CLIENT

import pytest


class TempUser:
def __init__(self, c, user, db, auth=None, authdata=None, password=None):
Expand Down Expand Up @@ -478,6 +478,162 @@ def test_defer_connect(self):
c.close()
sock.close()

def test_ssl_connect(self):
dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl={
"ca": "ca",
"cert": "cert",
"key": "key",
"cipher": "cipher",
},
)
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_called_with("cipher")

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl={
"ca": "ca",
"cert": "cert",
"key": "key",
},
)
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_ca="ca",
)
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_not_called
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_ca="ca",
ssl_cert="cert",
ssl_key="key",
)
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

for ssl_verify_cert in (True, "1", "yes", "true"):
dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_cert="cert",
ssl_key="key",
ssl_verify_cert=ssl_verify_cert,
)
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

for ssl_verify_cert in (None, False, "0", "no", "false"):
dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_cert="cert",
ssl_key="key",
ssl_verify_cert=ssl_verify_cert,
)
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

for ssl_ca in ("ca", None):
for ssl_verify_cert in ("foo", "bar", ""):
dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_ca=ssl_ca,
ssl_cert="cert",
ssl_key="key",
ssl_verify_cert=ssl_verify_cert,
)
assert create_default_context.called
assert not dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == (ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE), (ssl_ca, ssl_verify_cert)
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_ca="ca",
ssl_cert="cert",
ssl_key="key",
ssl_verify_identity=True,
)
assert create_default_context.called
assert dummy_ssl_context.check_hostname
assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key")
dummy_ssl_context.set_ciphers.assert_not_called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_disabled=True,
ssl={
"ca": "ca",
"cert": "cert",
"key": "key",
},
)
assert not create_default_context.called

dummy_ssl_context = mock.Mock(options=0)
with mock.patch("pymysql.connections.Connection.connect") as connect, \
mock.patch("pymysql.connections.ssl.create_default_context",
new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context:
pymysql.connect(
ssl_disabled=True,
ssl_ca="ca",
ssl_cert="cert",
ssl_key="key",
)
assert not create_default_context.called


# A custom type and function to escape it
class Foo(object):
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cryptography
PyNaCl>=1.4.0
pytest
mock