|
31 | 31 |
|
32 | 32 | import sys
|
33 | 33 | import socket
|
| 34 | +import logging |
34 | 35 |
|
35 | 36 | from functools import wraps
|
36 | 37 |
|
37 | 38 | from .authentication import MySQL41AuthPlugin
|
38 | 39 | from .errors import InterfaceError, OperationalError, ProgrammingError
|
39 | 40 | from .compat import PY3, STRING_TYPES, UNICODE_TYPES
|
40 | 41 | from .crud import Schema
|
| 42 | +from .constants import SSLMode |
41 | 43 | from .protocol import Protocol, MessageReaderWriter
|
42 | 44 | from .result import Result, RowResult, DocResult
|
43 | 45 | from .statement import SqlStatement, AddStatement
|
44 | 46 |
|
45 | 47 |
|
46 | 48 | _DROP_DATABASE_QUERY = "DROP DATABASE IF EXISTS `{0}`"
|
47 | 49 | _CREATE_DATABASE_QUERY = "CREATE DATABASE IF NOT EXISTS `{0}`"
|
48 |
| - |
| 50 | +_LOGGER = logging.getLogger("mysqlx") |
49 | 51 |
|
50 | 52 | class SocketStream(object):
|
51 | 53 | def __init__(self):
|
52 | 54 | self._socket = None
|
53 | 55 | self._is_ssl = False
|
| 56 | + self._host = None |
54 | 57 |
|
55 | 58 | def connect(self, params):
|
56 | 59 | if isinstance(params, tuple):
|
| 60 | + self._host = params[0] |
57 | 61 | s_type = socket.AF_INET6 if ":" in params[0] else socket.AF_INET
|
58 | 62 | else:
|
59 | 63 | s_type = socket.AF_UNIX
|
@@ -84,39 +88,46 @@ def close(self):
|
84 | 88 | self._socket.close()
|
85 | 89 | self._socket = None
|
86 | 90 |
|
87 |
| - def set_ssl(self, ssl_opts={}): |
| 91 | + def set_ssl(self, ssl_mode, ssl_ca, ssl_crl, ssl_cert, ssl_key): |
88 | 92 | if not SSL_AVAILABLE:
|
89 | 93 | self.close()
|
90 | 94 | raise RuntimeError("Python installation has no SSL support.")
|
91 | 95 |
|
92 | 96 | context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
93 | 97 | context.load_default_certs()
|
94 |
| - if "ssl-ca" in ssl_opts: |
| 98 | + |
| 99 | + if ssl_ca: |
95 | 100 | try:
|
96 |
| - context.load_verify_locations(ssl_opts["ssl-ca"]) |
| 101 | + context.load_verify_locations(ssl_ca) |
97 | 102 | context.verify_mode = ssl.CERT_REQUIRED
|
98 |
| - except (IOError, ssl.SSLError): |
| 103 | + except (IOError, ssl.SSLError) as err: |
99 | 104 | self.close()
|
100 |
| - raise InterfaceError("Invalid CA certificate.") |
101 |
| - if "ssl-crl" in ssl_opts: |
| 105 | + raise InterfaceError("Invalid CA Certificate: {}".format(err)) |
| 106 | + |
| 107 | + if ssl_crl: |
102 | 108 | try:
|
103 |
| - context.load_verify_locations(ssl_opts["ssl-crl"]) |
104 |
| - context.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN |
105 |
| - except (IOError, ssl.SSLError): |
| 109 | + context.load_verify_locations(ssl_crl) |
| 110 | + context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF |
| 111 | + except (IOError, ssl.SSLError) as err: |
106 | 112 | self.close()
|
107 |
| - raise InterfaceError("Invalid CRL.") |
108 |
| - if "ssl-cert" in ssl_opts: |
| 113 | + raise InterfaceError("Invalid CRL: {}".format(err)) |
| 114 | + |
| 115 | + if ssl_cert: |
109 | 116 | try:
|
110 |
| - context.load_cert_chain(ssl_opts["ssl-cert"], |
111 |
| - ssl_opts.get("ssl-key", None)) |
112 |
| - except (IOError, ssl.SSLError): |
| 117 | + context.load_cert_chain(ssl_cert, ssl_key) |
| 118 | + except (IOError, ssl.SSLError) as err: |
113 | 119 | self.close()
|
114 |
| - raise InterfaceError("Invalid Client Certificate/Key.") |
115 |
| - elif "ssl-key" in ssl_opts: |
116 |
| - self.close() |
117 |
| - raise InterfaceError("Client Certificate not provided.") |
| 120 | + raise InterfaceError("Invalid Certificate/Key: {}".format(err)) |
118 | 121 |
|
119 | 122 | self._socket = context.wrap_socket(self._socket)
|
| 123 | + if ssl_mode == SSLMode.VERIFY_IDENTITY: |
| 124 | + try: |
| 125 | + hostname = socket.gethostbyaddr(self._host) |
| 126 | + ssl.match_hostname(self._socket.getpeercert(), hostname[0]) |
| 127 | + except ssl.CertificateError as err: |
| 128 | + self.close() |
| 129 | + raise InterfaceError("Unable to verify server identity: {}" |
| 130 | + "".format(err)) |
120 | 131 | self._is_ssl = True
|
121 | 132 |
|
122 | 133 |
|
@@ -223,22 +234,29 @@ def connect(self):
|
223 | 234 | raise InterfaceError("Failed to connect to any of the routers.", 4001)
|
224 | 235 |
|
225 | 236 | def _handle_capabilities(self):
|
| 237 | + if self.settings.get("ssl-mode") == SSLMode.DISABLED: |
| 238 | + return |
| 239 | + if "socket" in self.settings: |
| 240 | + if self.settings.get("ssl-mode"): |
| 241 | + _LOGGER.warning("SSL not required when using Unix socket.") |
| 242 | + return |
| 243 | + |
226 | 244 | data = self.protocol.get_capabilites().capabilities
|
227 | 245 | if not (data[0]["name"].lower() == "tls" if data else False):
|
228 |
| - if self.settings.get("ssl-enable", False): |
229 |
| - self.close() |
230 |
| - raise OperationalError("SSL not enabled at server.") |
231 |
| - return |
| 246 | + self.close() |
| 247 | + raise OperationalError("SSL not enabled at server.") |
232 | 248 |
|
233 | 249 | if sys.version_info < (2, 7, 9):
|
234 |
| - if self.settings.get("ssl-enable", False): |
235 |
| - self.close() |
236 |
| - raise RuntimeError("The support for SSL is not available for " |
237 |
| - "this Python version.") |
238 |
| - return |
| 250 | + self.close() |
| 251 | + raise RuntimeError("The support for SSL is not available for " |
| 252 | + "this Python version.") |
239 | 253 |
|
240 | 254 | self.protocol.set_capabilities(tls=True)
|
241 |
| - self.stream.set_ssl(self.settings) |
| 255 | + self.stream.set_ssl(self.settings.get("ssl-mode", SSLMode.REQUIRED), |
| 256 | + self.settings.get("ssl-ca"), |
| 257 | + self.settings.get("ssl-crl"), |
| 258 | + self.settings.get("ssl-cert"), |
| 259 | + self.settings.get("ssl-key")) |
242 | 260 |
|
243 | 261 | def _authenticate(self):
|
244 | 262 | plugin = MySQL41AuthPlugin(self._user, self._password)
|
|
0 commit comments