From b799bd9f3527e8c33c154c989301dae6f2a2cf70 Mon Sep 17 00:00:00 2001 From: lishuode Date: Thu, 22 Jun 2017 07:28:15 +0000 Subject: [PATCH 01/20] Add support for SHA256 auth plugin --- pymysql/auth/__init__.py | 0 pymysql/auth/sha256_password_plugin.py | 58 ++++++++++++++++++++++++++ pymysql/connections.py | 54 +++++++++++++++++------- pymysql/protocol.py | 4 ++ pymysql/tests/test_connection.py | 5 ++- 5 files changed, 103 insertions(+), 18 deletions(-) create mode 100644 pymysql/auth/__init__.py create mode 100644 pymysql/auth/sha256_password_plugin.py diff --git a/pymysql/auth/__init__.py b/pymysql/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py new file mode 100644 index 00000000..539b9249 --- /dev/null +++ b/pymysql/auth/sha256_password_plugin.py @@ -0,0 +1,58 @@ +from ..constants import CLIENT +from ..err import OperationalError + +# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm +# which is needed when use sha256_password_plugin with no SSL +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization, hashes + from cryptography.hazmat.primitives.asymmetric import padding + HAVE_CRYPTOGRAPHY = True +except ImportError: + HAVE_CRYPTOGRAPHY = False + + +def _xor_password(password, salt): + password_bytes = bytearray(password, 'ascii') + salt_len = len(salt) + for i in range(len(password_bytes)): + password_bytes[i] ^= ord(salt[i % salt_len]) + return password_bytes + + +def _sha256_rsa_crypt(password, salt, public_key): + if not HAVE_CRYPTOGRAPHY: + raise OperationalError("cryptography module not found for sha256_password_plugin") + message = _xor_password(password + b'\0', salt) + rsa_key = serialization.load_pem_public_key(public_key, default_backend()) + return rsa_key.encrypt( + message.decode('latin1').encode('latin1'), padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA1()), + algorithm=hashes.SHA1(), + label=None)) + + +class SHA256PasswordPlugin(object): + def __init__(self, con): + self.con = con + + def authenticate(self, pkt): + if self.con.ssl and self.con.server_capabilities & CLIENT.SSL: + data = self.con.password.encode('latin1') + b'\0' + else: + if pkt.is_auth_switch_request(): + self.con.salt = pkt.read_all() + if self.con.server_public_key == '': + self.con.write_packet(b'\1') + pkt = self.con._read_packet() + if pkt.is_extra_auth_data() and self.con.server_public_key == '': + pkt.read_uint8() + self.con.server_public_key = pkt.read_all() + data = _sha256_rsa_crypt( + self.con.password, + self.con.salt, + self.con.server_public_key) + self.con.write_packet(data) + pkt = self.con._read_packet() + pkt.check_error() + return pkt diff --git a/pymysql/connections.py b/pymysql/connections.py index 53e18e3c..0b406e55 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -16,6 +16,7 @@ import traceback import warnings +from .auth import sha256_password_plugin as _auth from .charset import charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS from . import converters @@ -43,7 +44,6 @@ # KeyError occurs when there's no entry in OS database for a current user. DEFAULT_USER = None - DEBUG = False _py_version = sys.version_info[:2] @@ -107,7 +107,6 @@ def _scramble(password, message): result = s.digest() return _my_crypt(result, stage1) - def _my_crypt(message1, message2): length = len(message1) result = b'' @@ -186,6 +185,7 @@ def lenenc_int(i): else: raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64))) + class Connection(object): """ Representation of a socket with a mysql server. @@ -240,6 +240,7 @@ class Connection(object): The class needs an authenticate method taking an authentication packet as an argument. For the dialog plugin, a prompt(echo, prompt) method can be used (if no authenticate method) for returning a string from the user. (experimental) + :param server_public_key: SHA256 authenticaiton plugin public key value. (default: '') :param db: Alias for database. (for compatibility to MySQLdb) :param passwd: Alias for password. (for compatibility to MySQLdb) :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False) @@ -262,7 +263,7 @@ def __init__(self, host=None, user=None, password="", autocommit=False, db=None, passwd=None, local_infile=False, max_allowed_packet=16*1024*1024, defer_connect=False, auth_plugin_map={}, read_timeout=None, write_timeout=None, - bind_address=None, binary_prefix=False): + bind_address=None, binary_prefix=False, server_public_key=''): if no_delay is not None: warnings.warn("no_delay option is deprecated", DeprecationWarning) @@ -379,6 +380,9 @@ def _config(key, arg): self.max_allowed_packet = max_allowed_packet self._auth_plugin_map = auth_plugin_map self._binary_prefix = binary_prefix + if b"sha256_password" not in self._auth_plugin_map: + self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin + self.server_public_key = server_public_key if defer_connect: self._sock = None else: @@ -507,7 +511,7 @@ def select_db(self, db): def escape(self, obj, mapping=None): """Escape whatever value you pass to it. - + Non-standard, for internal use; do not use this in your applications. """ if isinstance(obj, str_type): @@ -521,7 +525,7 @@ def escape(self, obj, mapping=None): def literal(self, obj): """Alias for escape() - + Non-standard, for internal use; do not use this in your applications. """ return self.escape(obj, self.encoders) @@ -861,6 +865,14 @@ def _request_authentication(self): authresp = b'' if self._auth_plugin_name in ('', 'mysql_native_password'): authresp = _scramble(self.password.encode('latin1'), self.salt) + elif self._auth_plugin_name == 'sha256_password': + if self.ssl and self.server_capabilities & CLIENT.SSL: + authresp = self.password.encode('latin1') + b'\0' + else: + if self.password is not None: + authresp = b'\1' + else: + authresp = b'\0' if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += lenenc_int(len(authresp)) + authresp @@ -896,24 +908,20 @@ def _request_authentication(self): data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0' self.write_packet(data) auth_packet = self._read_packet() + elif auth_packet.is_extra_auth_data(): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + handler = self._get_auth_plugin_handler(self._auth_plugin_name) + handler.authenticate(auth_packet) def _process_auth(self, plugin_name, auth_packet): - plugin_class = self._auth_plugin_map.get(plugin_name) - if not plugin_class: - plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) - if plugin_class: + handler = self._get_auth_plugin_handler(plugin_name) + if handler != None: try: - handler = plugin_class(self) return handler.authenticate(auth_packet) except AttributeError: if plugin_name != b'dialog': raise err.OperationalError(2059, "Authentication plugin '%s'" \ " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class)) - except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) - else: - handler = None if plugin_name == b"mysql_native_password": # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) @@ -958,6 +966,20 @@ def _process_auth(self, plugin_name, auth_packet): pkt = self._read_packet() pkt.check_error() return pkt + + def _get_auth_plugin_handler(self, plugin_name): + plugin_class = self._auth_plugin_map.get(plugin_name) + if not plugin_class: + plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) + if plugin_class: + try: + handler = plugin_class(self) + except TypeError: + raise err.OperationalError(2059, "Authentication plugin '%s'" \ + " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) + else: + handler = None + return handler # _mysql support def thread_id(self): @@ -1232,7 +1254,7 @@ def _get_descriptions(self): # This behavior is different from TEXT / BLOB. # We should decode result by connection encoding regardless charsetnr. # See https://github.com/PyMySQL/PyMySQL/issues/488 - encoding = conn_encoding # SELECT CAST(... AS JSON) + encoding = conn_encoding # SELECT CAST(... AS JSON) elif field_type in TEXT_TYPES: if field.charsetnr == 63: # binary # TEXTs with charset=binary means BINARY types. diff --git a/pymysql/protocol.py b/pymysql/protocol.py index e872a0eb..f36e2ae3 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -196,6 +196,10 @@ def is_auth_switch_request(self): # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest return self._data[0:1] == b'\xfe' + def is_extra_auth_data(self): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + return self._data[0:1] == b'\x01' + def is_resultset_packet(self): field_count = ord(self._data[0:1]) return 1 <= field_count <= 250 diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index c626a0d3..28091be2 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -360,11 +360,12 @@ def testAuthSHA256(self): else: c.execute('SET old_passwords = 2') c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')") + c.execute("FLUSH PRIVILEGES") db = self.db.copy() db['password'] = "Sh@256Pa33" - # not implemented yet so thows error + # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_256', **db) + pymysql.connect(user='pymysql_sha256', **db) class TestConnection(base.PyMySQLTestCase): From 56a12f066be305ce4748183578ff1ddce0f8fa2d Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 8 Jun 2018 22:04:48 +0900 Subject: [PATCH 02/20] Add cryptography as dependency --- pymysql/auth/sha256_password_plugin.py | 14 +++----------- setup.py | 3 +++ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py index 539b9249..5d604248 100644 --- a/pymysql/auth/sha256_password_plugin.py +++ b/pymysql/auth/sha256_password_plugin.py @@ -1,15 +1,9 @@ from ..constants import CLIENT from ..err import OperationalError -# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm -# which is needed when use sha256_password_plugin with no SSL -try: - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization, hashes - from cryptography.hazmat.primitives.asymmetric import padding - HAVE_CRYPTOGRAPHY = True -except ImportError: - HAVE_CRYPTOGRAPHY = False +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import padding def _xor_password(password, salt): @@ -21,8 +15,6 @@ def _xor_password(password, salt): def _sha256_rsa_crypt(password, salt, public_key): - if not HAVE_CRYPTOGRAPHY: - raise OperationalError("cryptography module not found for sha256_password_plugin") message = _xor_password(password + b'\0', salt) rsa_key = serialization.load_pem_public_key(public_key, default_backend()) return rsa_key.encrypt( diff --git a/setup.py b/setup.py index 37342d4b..3d8b50b5 100755 --- a/setup.py +++ b/setup.py @@ -27,6 +27,9 @@ long_description=readme, license="MIT", packages=find_packages(), + install_requires=[ + "cryptography", + ], classifiers=[ 'Development Status :: 5 - Production/Stable', 'Programming Language :: Python :: 2', From 8b7c7c740eb4a457dbea81418ac1882c82c0f965 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 8 Jun 2018 22:36:19 +0900 Subject: [PATCH 03/20] Fix UnicodeError --- pymysql/auth/sha256_password_plugin.py | 8 +++++--- pymysql/connections.py | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py index 5d604248..34b5477b 100644 --- a/pymysql/auth/sha256_password_plugin.py +++ b/pymysql/auth/sha256_password_plugin.py @@ -1,3 +1,4 @@ +from .._compat import text_type from ..constants import CLIENT from ..err import OperationalError @@ -7,10 +8,11 @@ def _xor_password(password, salt): - password_bytes = bytearray(password, 'ascii') + password_bytes = bytearray(password) + salt = bytearray(salt) # for PY2 compat. salt_len = len(salt) for i in range(len(password_bytes)): - password_bytes[i] ^= ord(salt[i % salt_len]) + password_bytes[i] ^= salt[i % salt_len] return password_bytes @@ -30,7 +32,7 @@ def __init__(self, con): def authenticate(self, pkt): if self.con.ssl and self.con.server_capabilities & CLIENT.SSL: - data = self.con.password.encode('latin1') + b'\0' + data = self.con.password + b'\0' else: if pkt.is_auth_switch_request(): self.con.salt = pkt.read_all() diff --git a/pymysql/connections.py b/pymysql/connections.py index 0b406e55..a5c0c00d 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -331,6 +331,8 @@ def _config(key, arg): self.port = port or 3306 self.user = user or DEFAULT_USER self.password = password or "" + if isinstance(self.password, text_type): + self.password = self.password.encode('ascii') self.db = database self.unix_socket = unix_socket self.bind_address = bind_address From 6a5f4fad78a82b7f1f55b18d6dcc5bc584dceae0 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 21 Jun 2018 22:50:44 +0900 Subject: [PATCH 04/20] wip --- pymysql/auth/sha256_password_plugin.py | 139 +++++++++++++++++++++---- pymysql/connections.py | 65 +++++++----- pymysql/protocol.py | 11 +- sha256_test.py | 26 +++++ 4 files changed, 189 insertions(+), 52 deletions(-) create mode 100644 sha256_test.py diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py index 34b5477b..ee41747e 100644 --- a/pymysql/auth/sha256_password_plugin.py +++ b/pymysql/auth/sha256_password_plugin.py @@ -1,3 +1,6 @@ +""" +Implements sha256_password and caching_sha2_password auth methods. +""" from .._compat import text_type from ..constants import CLIENT from ..err import OperationalError @@ -6,6 +9,18 @@ from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives.asymmetric import padding +import hashlib + + +DEBUG = True + + +def _roundtrip(conn, send_data): + conn.write_packet(send_data) + pkt = conn._read_packet() + pkt.check_error() + return pkt + def _xor_password(password, salt): password_bytes = bytearray(password) @@ -26,27 +41,109 @@ def _sha256_rsa_crypt(password, salt, public_key): label=None)) -class SHA256PasswordPlugin(object): - def __init__(self, con): - self.con = con +class SHA256Password(object): + def __init__(self, conn): + self.conn = conn def authenticate(self, pkt): - if self.con.ssl and self.con.server_capabilities & CLIENT.SSL: - data = self.con.password + b'\0' + conn = self.conn + + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: print("sha256: Sending plain password") + data = conn.password + b'\0' + return _roundtrip(conn, data) + + if pkt.is_auth_switch_request(): + conn.salt = pkt.read_all() + if not conn.server_public_key and conn.password: + # Request server public key + if DEBUG: print("sha256: Requesting server public key") + pkt = _roundtrip(conn, b'\1') + + if pkt.is_extra_auth_data(): + conn.server_public_key = pkt._data[1:] + if DEBUG: + print("Received public key:\n", conn.server_public_key.decode('ascii')) + + if conn.password: + if not conn.server_public_key: + raise OperationalError("Couldn't receive server's public key") + data = _sha256_rsa_crypt(conn.password, conn.salt, conn.server_public_key) else: - if pkt.is_auth_switch_request(): - self.con.salt = pkt.read_all() - if self.con.server_public_key == '': - self.con.write_packet(b'\1') - pkt = self.con._read_packet() - if pkt.is_extra_auth_data() and self.con.server_public_key == '': - pkt.read_uint8() - self.con.server_public_key = pkt.read_all() - data = _sha256_rsa_crypt( - self.con.password, - self.con.salt, - self.con.server_public_key) - self.con.write_packet(data) - pkt = self.con._read_packet() - pkt.check_error() - return pkt + data = b'' + + return _roundtrip(conn, data) + + +# XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) +# Used in caching_sha2_password +def _scramble_sha256_password(password, scramble): + if not password: + return b'' + + p1 = hashlib.sha256(password).digest() + p2 = hashlib.sha256(p1).digest() + p3 = hashlib.sha256(p2 + scramble).digest() + + res = bytearray(p1) + for i in range(len(p3)): + res[i] ^= p3[i] + + return bytes(res) + + +class CachingSHA2Password(object): + def __init__(self, conn): + self.conn = conn + + def authenticate(self, pkt): + conn = self.conn + + # No password fast path + if not conn.password: + return _roundtrip(conn, b'') + + if pkt.is_auth_switch_request(): + # Try from fast auth + if DEBUG: print("caching sha2: Trying fast path") + conn.salt = pkt.read_all() + scrambled = _scramble_sha256_password(conn.password, conn.salt) + pkt = _roundtrip(conn, scrambled) + #else: fast auth is tried in initial handshake + + if not pkt.is_extra_auth_data(): + raise OperationalError("caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]) + + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + pkt.advance(1) + n = pkt.read_uint8() + + if n == 3: + if DEBUG: print("caching sha2: succeeded by fast path.") + pkt = conn._read_packet() + pkt.check_error() # pkt must be OK packet + return pkt + + if n != 4: + raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) + + if DEBUG: print("caching sha2: Trying full auth...") + + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: print("caching sha2: Sending plain password via SSL") + return _roundtrip(conn, conn.password + b'\0') + + if not conn.server_public_key: + pkt = _roundtrip(conn, b'\x02') # Request public key + if not pkt.is_extra_auth_data(): + raise OperationalError("caching sha2: Unknown packet for public key: %s" % pkt._data[:1]) + conn.server_public_key = pkt._data[1:] + if DEBUG: + print(conn.server_public_key.decode('ascii')) + + data = _sha256_rsa_crypt(conn.password, conn.salt, conn.server_public_key) + pkt = _roundtrip(conn, data) diff --git a/pymysql/connections.py b/pymysql/connections.py index a5c0c00d..75b3d207 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -240,7 +240,7 @@ class Connection(object): The class needs an authenticate method taking an authentication packet as an argument. For the dialog plugin, a prompt(echo, prompt) method can be used (if no authenticate method) for returning a string from the user. (experimental) - :param server_public_key: SHA256 authenticaiton plugin public key value. (default: '') + :param server_public_key: SHA256 authenticaiton plugin public key value. (default: None) :param db: Alias for database. (for compatibility to MySQLdb) :param passwd: Alias for password. (for compatibility to MySQLdb) :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False) @@ -262,8 +262,8 @@ def __init__(self, host=None, user=None, password="", compress=None, named_pipe=None, no_delay=None, autocommit=False, db=None, passwd=None, local_infile=False, max_allowed_packet=16*1024*1024, defer_connect=False, - auth_plugin_map={}, read_timeout=None, write_timeout=None, - bind_address=None, binary_prefix=False, server_public_key=''): + auth_plugin_map=None, read_timeout=None, write_timeout=None, + bind_address=None, binary_prefix=False, server_public_key=None): if no_delay is not None: warnings.warn("no_delay option is deprecated", DeprecationWarning) @@ -330,9 +330,9 @@ def _config(key, arg): self.host = host or "localhost" self.port = port or 3306 self.user = user or DEFAULT_USER - self.password = password or "" + self.password = password or b"" if isinstance(self.password, text_type): - self.password = self.password.encode('ascii') + self.password = self.password.encode('latin1') self.db = database self.unix_socket = unix_socket self.bind_address = bind_address @@ -380,10 +380,12 @@ def _config(key, arg): self.sql_mode = sql_mode self.init_command = init_command self.max_allowed_packet = max_allowed_packet - self._auth_plugin_map = auth_plugin_map + self._auth_plugin_map = auth_plugin_map or {} self._binary_prefix = binary_prefix - if b"sha256_password" not in self._auth_plugin_map: - self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin + if "caching_sha2_password" not in self._auth_plugin_map: + self._auth_plugin_map["caching_sha2_password"] = _auth.CachingSHA2Password + if "sha256_password" not in self._auth_plugin_map: + self._auth_plugin_map["sha256_password"] = _auth.SHA256Password self.server_public_key = server_public_key if defer_connect: self._sock = None @@ -854,7 +856,7 @@ def _request_authentication(self): if isinstance(self.user, text_type): self.user = self.user.encode(self.encoding) - data_init = struct.pack(' Date: Fri, 22 Jun 2018 16:38:55 +0900 Subject: [PATCH 05/20] Make auth package to _auth module --- pymysql/{auth/sha256_password_plugin.py => _auth.py} | 8 ++++---- pymysql/auth/__init__.py | 0 pymysql/connections.py | 3 ++- sha256_test.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) rename pymysql/{auth/sha256_password_plugin.py => _auth.py} (96%) delete mode 100644 pymysql/auth/__init__.py diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/_auth.py similarity index 96% rename from pymysql/auth/sha256_password_plugin.py rename to pymysql/_auth.py index ee41747e..cd013c1b 100644 --- a/pymysql/auth/sha256_password_plugin.py +++ b/pymysql/_auth.py @@ -1,9 +1,9 @@ """ -Implements sha256_password and caching_sha2_password auth methods. +Implements auth methods """ -from .._compat import text_type -from ..constants import CLIENT -from ..err import OperationalError +from ._compat import text_type +from .constants import CLIENT +from .err import OperationalError from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization, hashes diff --git a/pymysql/auth/__init__.py b/pymysql/auth/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pymysql/connections.py b/pymysql/connections.py index 75b3d207..fdabe08f 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -16,7 +16,8 @@ import traceback import warnings -from .auth import sha256_password_plugin as _auth +from . import _auth + from .charset import charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS from . import converters diff --git a/sha256_test.py b/sha256_test.py index 43e0f433..3e6f0dab 100644 --- a/sha256_test.py +++ b/sha256_test.py @@ -1,7 +1,7 @@ import pymysql pymysql.connections.DEBUG = True -pymysql.auth.sha256_password_plugin.DEBUG = True +pymysql._auth.DEBUG = True host="127.0.0.1" port=3306 From 0f31c0142db20691e792c56d027c22a644e1afa9 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 22 Jun 2018 17:19:56 +0900 Subject: [PATCH 06/20] blaken _auth.py --- pymysql/_auth.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index cd013c1b..dc75113e 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -35,13 +35,17 @@ def _sha256_rsa_crypt(password, salt, public_key): message = _xor_password(password + b'\0', salt) rsa_key = serialization.load_pem_public_key(public_key, default_backend()) return rsa_key.encrypt( - message.decode('latin1').encode('latin1'), padding.OAEP( + message, + padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), - label=None)) + label=None, + ), + ) class SHA256Password(object): + def __init__(self, conn): self.conn = conn @@ -49,7 +53,8 @@ def authenticate(self, pkt): conn = self.conn if conn.ssl and conn.server_capabilities & CLIENT.SSL: - if DEBUG: print("sha256: Sending plain password") + if DEBUG: + print("sha256: Sending plain password") data = conn.password + b'\0' return _roundtrip(conn, data) @@ -57,7 +62,8 @@ def authenticate(self, pkt): conn.salt = pkt.read_all() if not conn.server_public_key and conn.password: # Request server public key - if DEBUG: print("sha256: Requesting server public key") + if DEBUG: + print("sha256: Requesting server public key") pkt = _roundtrip(conn, b'\1') if pkt.is_extra_auth_data(): @@ -68,6 +74,7 @@ def authenticate(self, pkt): if conn.password: if not conn.server_public_key: raise OperationalError("Couldn't receive server's public key") + data = _sha256_rsa_crypt(conn.password, conn.salt, conn.server_public_key) else: data = b'' @@ -77,6 +84,8 @@ def authenticate(self, pkt): # XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) # Used in caching_sha2_password + + def _scramble_sha256_password(password, scramble): if not password: return b'' @@ -93,6 +102,7 @@ def _scramble_sha256_password(password, scramble): class CachingSHA2Password(object): + def __init__(self, conn): self.conn = conn @@ -105,14 +115,17 @@ def authenticate(self, pkt): if pkt.is_auth_switch_request(): # Try from fast auth - if DEBUG: print("caching sha2: Trying fast path") + if DEBUG: + print("caching sha2: Trying fast path") conn.salt = pkt.read_all() scrambled = _scramble_sha256_password(conn.password, conn.salt) pkt = _roundtrip(conn, scrambled) - #else: fast auth is tried in initial handshake + # else: fast auth is tried in initial handshake if not pkt.is_extra_auth_data(): - raise OperationalError("caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]) + raise OperationalError( + "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1] + ) # magic numbers: # 2 - request public key @@ -123,7 +136,8 @@ def authenticate(self, pkt): n = pkt.read_uint8() if n == 3: - if DEBUG: print("caching sha2: succeeded by fast path.") + if DEBUG: + print("caching sha2: succeeded by fast path.") pkt = conn._read_packet() pkt.check_error() # pkt must be OK packet return pkt @@ -131,16 +145,21 @@ def authenticate(self, pkt): if n != 4: raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) - if DEBUG: print("caching sha2: Trying full auth...") + if DEBUG: + print("caching sha2: Trying full auth...") if conn.ssl and conn.server_capabilities & CLIENT.SSL: - if DEBUG: print("caching sha2: Sending plain password via SSL") + if DEBUG: + print("caching sha2: Sending plain password via SSL") return _roundtrip(conn, conn.password + b'\0') if not conn.server_public_key: pkt = _roundtrip(conn, b'\x02') # Request public key if not pkt.is_extra_auth_data(): - raise OperationalError("caching sha2: Unknown packet for public key: %s" % pkt._data[:1]) + raise OperationalError( + "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] + ) + conn.server_public_key = pkt._data[1:] if DEBUG: print(conn.server_public_key.decode('ascii')) From 977172c25b2fdc95bf086c9a61a4584e0d031b3e Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 22 Jun 2018 18:06:39 +0900 Subject: [PATCH 07/20] Update .gitignore --- .gitignore | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index cd93a4b0..0b2c85be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,14 @@ *.pyc *.pyo -__pycache__ -.coverage -/dist -/PyMySQL.egg-info +/.cache +/.coverage +/.idea /.tox +/.venv +/.vscode +/PyMySQL.egg-info /build +/dist +/docs/build /pymysql/tests/databases.json - -/.idea -docs/build +__pycache__ From 52c06217e7f608b8ee8298a5ec225f8c5a78ac93 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 22 Jun 2018 22:41:01 +0900 Subject: [PATCH 08/20] cleanup --- pymysql/_auth.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index dc75113e..9195dfce 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -31,7 +31,11 @@ def _xor_password(password, salt): return password_bytes -def _sha256_rsa_crypt(password, salt, public_key): +def sha2_rsa_encrypt(password, salt, public_key): + """Encrypt password with salt and public_key. + + Used for sha256_password and caching_sha2_password. + """ message = _xor_password(password + b'\0', salt) rsa_key = serialization.load_pem_public_key(public_key, default_backend()) return rsa_key.encrypt( @@ -75,24 +79,25 @@ def authenticate(self, pkt): if not conn.server_public_key: raise OperationalError("Couldn't receive server's public key") - data = _sha256_rsa_crypt(conn.password, conn.salt, conn.server_public_key) + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) else: data = b'' return _roundtrip(conn, data) -# XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) -# Used in caching_sha2_password - +def caching_sha2_scramble(password, nonce): + # (bytes, bytes) -> bytes + """Scramble algorithm used in cached_sha2_password fast path. -def _scramble_sha256_password(password, scramble): + XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce)) + """ if not password: return b'' p1 = hashlib.sha256(password).digest() p2 = hashlib.sha256(p1).digest() - p3 = hashlib.sha256(p2 + scramble).digest() + p3 = hashlib.sha256(p2 + nonce).digest() res = bytearray(p1) for i in range(len(p3)): @@ -118,7 +123,7 @@ def authenticate(self, pkt): if DEBUG: print("caching sha2: Trying fast path") conn.salt = pkt.read_all() - scrambled = _scramble_sha256_password(conn.password, conn.salt) + scrambled = caching_sha2_scramble(conn.password, conn.salt) pkt = _roundtrip(conn, scrambled) # else: fast auth is tried in initial handshake @@ -164,5 +169,5 @@ def authenticate(self, pkt): if DEBUG: print(conn.server_public_key.decode('ascii')) - data = _sha256_rsa_crypt(conn.password, conn.salt, conn.server_public_key) + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) pkt = _roundtrip(conn, data) From 76c7c80fb28b32194588110c50e79de0f56aa955 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 22 Jun 2018 22:43:10 +0900 Subject: [PATCH 09/20] travis: Install cryptography --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 2822cd05..e5dc76c6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,7 +41,7 @@ matrix: # http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version # really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't install: - - pip install -U coveralls unittest2 coverage + - pip install -U coveralls unittest2 coverage cryptography before_script: - ./.travis/initializedb.sh From 832b242baa1724a75e66da48de2840a179be005f Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 22 Jun 2018 22:59:32 +0900 Subject: [PATCH 10/20] Move other auth methods to _auth --- pymysql/_auth.py | 249 +++++++++++++++++++++++++++-------------- pymysql/connections.py | 100 +++-------------- 2 files changed, 179 insertions(+), 170 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index 9195dfce..ecba9d65 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -9,10 +9,103 @@ from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives.asymmetric import padding +from functools import partial import hashlib +import struct DEBUG = True +SCRAMBLE_LENGTH = 20 +sha1_new = partial(hashlib.new, 'sha1') + + +# mysql_native_password +# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 + + +def scramble_native_password(password, message): + """Scramble used for mysql_native_password""" + if not password: + return b'' + + stage1 = sha1_new(password).digest() + stage2 = sha1_new(stage1).digest() + s = sha1_new() + s.update(message[:SCRAMBLE_LENGTH]) + s.update(stage2) + result = s.digest() + return _my_crypt(result, stage1) + + +def _my_crypt(message1, message2): + length = len(message1) + result = b'' + for i in range_type(length): + x = ( + struct.unpack('B', message1[i:i + 1])[0] ^ + struct.unpack('B', message2[i:i + 1])[0] + ) + result += struct.pack('B', x) + return result + + +# old_passwords support ported from libmysql/password.c +# https://dev.mysql.com/doc/internals/en/old-password-authentication.html + +SCRAMBLE_LENGTH_323 = 8 + + +class RandStruct_323(object): + + def __init__(self, seed1, seed2): + self.max_value = 0x3FFFFFFF + self.seed1 = seed1 % self.max_value + self.seed2 = seed2 % self.max_value + + def my_rnd(self): + self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value + self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value + return float(self.seed1) / float(self.max_value) + + +def scramble_old_password(password, message): + """Scramble for old_password""" + hash_pass = _hash_password_323(password) + hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) + hash_pass_n = struct.unpack(">LL", hash_pass) + hash_message_n = struct.unpack(">LL", hash_message) + + rand_st = RandStruct_323( + hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1] + ) + outbuf = io.BytesIO() + for _ in range_type(min(SCRAMBLE_LENGTH_323, len(message))): + outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) + extra = int2byte(int(rand_st.my_rnd() * 31)) + out = outbuf.getvalue() + outbuf = io.BytesIO() + for c in out: + outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) + return outbuf.getvalue() + + +def _hash_password_323(password): + nr = 1345345333 + add = 7 + nr2 = 0x12345671 + + # x in py3 is numbers, p27 is chars + for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: + nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF + nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF + add = (add + c) & 0xFFFFFFFF + + r1 = nr & ((1 << 31) - 1) # kill sign bits + r2 = nr2 & ((1 << 31) - 1) + return struct.pack(">LL", r1, r2) + + +# sha256_password def _roundtrip(conn, send_data): @@ -48,42 +141,35 @@ def sha2_rsa_encrypt(password, salt, public_key): ) -class SHA256Password(object): - - def __init__(self, conn): - self.conn = conn - - def authenticate(self, pkt): - conn = self.conn +def sha256_password_auth(conn, pkt): + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: + print("sha256: Sending plain password") + data = conn.password + b'\0' + return _roundtrip(conn, data) - if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if pkt.is_auth_switch_request(): + conn.salt = pkt.read_all() + if not conn.server_public_key and conn.password: + # Request server public key if DEBUG: - print("sha256: Sending plain password") - data = conn.password + b'\0' - return _roundtrip(conn, data) - - if pkt.is_auth_switch_request(): - conn.salt = pkt.read_all() - if not conn.server_public_key and conn.password: - # Request server public key - if DEBUG: - print("sha256: Requesting server public key") - pkt = _roundtrip(conn, b'\1') - - if pkt.is_extra_auth_data(): - conn.server_public_key = pkt._data[1:] - if DEBUG: - print("Received public key:\n", conn.server_public_key.decode('ascii')) + print("sha256: Requesting server public key") + pkt = _roundtrip(conn, b'\1') - if conn.password: - if not conn.server_public_key: - raise OperationalError("Couldn't receive server's public key") + if pkt.is_extra_auth_data(): + conn.server_public_key = pkt._data[1:] + if DEBUG: + print("Received public key:\n", conn.server_public_key.decode('ascii')) - data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) - else: - data = b'' + if conn.password: + if not conn.server_public_key: + raise OperationalError("Couldn't receive server's public key") - return _roundtrip(conn, data) + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) + else: + data = b'' + + return _roundtrip(conn, data) def caching_sha2_scramble(password, nonce): @@ -106,68 +192,61 @@ def caching_sha2_scramble(password, nonce): return bytes(res) -class CachingSHA2Password(object): +def caching_sha2_password_auth(conn, pkt): + # No password fast path + if not conn.password: + return _roundtrip(conn, b'') - def __init__(self, conn): - self.conn = conn + if pkt.is_auth_switch_request(): + # Try from fast auth + if DEBUG: + print("caching sha2: Trying fast path") + conn.salt = pkt.read_all() + scrambled = caching_sha2_scramble(conn.password, conn.salt) + pkt = _roundtrip(conn, scrambled) + # else: fast auth is tried in initial handshake + + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1] + ) + + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + pkt.advance(1) + n = pkt.read_uint8() + + if n == 3: + if DEBUG: + print("caching sha2: succeeded by fast path.") + pkt = conn._read_packet() + pkt.check_error() # pkt must be OK packet + return pkt - def authenticate(self, pkt): - conn = self.conn + if n != 4: + raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) - # No password fast path - if not conn.password: - return _roundtrip(conn, b'') + if DEBUG: + print("caching sha2: Trying full auth...") - if pkt.is_auth_switch_request(): - # Try from fast auth - if DEBUG: - print("caching sha2: Trying fast path") - conn.salt = pkt.read_all() - scrambled = caching_sha2_scramble(conn.password, conn.salt) - pkt = _roundtrip(conn, scrambled) - # else: fast auth is tried in initial handshake + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: + print("caching sha2: Sending plain password via SSL") + return _roundtrip(conn, conn.password + b'\0') + if not conn.server_public_key: + pkt = _roundtrip(conn, b'\x02') # Request public key if not pkt.is_extra_auth_data(): raise OperationalError( - "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1] + "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] ) - # magic numbers: - # 2 - request public key - # 3 - fast auth succeeded - # 4 - need full auth - - pkt.advance(1) - n = pkt.read_uint8() - - if n == 3: - if DEBUG: - print("caching sha2: succeeded by fast path.") - pkt = conn._read_packet() - pkt.check_error() # pkt must be OK packet - return pkt - - if n != 4: - raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) - + conn.server_public_key = pkt._data[1:] if DEBUG: - print("caching sha2: Trying full auth...") + print(conn.server_public_key.decode('ascii')) - if conn.ssl and conn.server_capabilities & CLIENT.SSL: - if DEBUG: - print("caching sha2: Sending plain password via SSL") - return _roundtrip(conn, conn.password + b'\0') - - if not conn.server_public_key: - pkt = _roundtrip(conn, b'\x02') # Request public key - if not pkt.is_extra_auth_data(): - raise OperationalError( - "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] - ) - - conn.server_public_key = pkt._data[1:] - if DEBUG: - print(conn.server_public_key.decode('ascii')) - - data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) - pkt = _roundtrip(conn, data) + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) + pkt = _roundtrip(conn, data) diff --git a/pymysql/connections.py b/pymysql/connections.py index 39111614..0099d95a 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -6,7 +6,6 @@ from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON import errno -from functools import partial import hashlib import io import os @@ -88,86 +87,11 @@ def _makefile(sock, mode): FIELD_TYPE.VARCHAR, FIELD_TYPE.GEOMETRY]) -sha_new = partial(hashlib.new, 'sha1') -DEFAULT_CHARSET = 'latin1' +DEFAULT_CHARSET = 'latin1' # TODO: change to utf8mb4 MAX_PACKET_LEN = 2**24-1 -SCRAMBLE_LENGTH = 20 - -def _scramble(password, message): - if not password: - return b'' - if DEBUG: print('password=' + str(password)) - stage1 = sha_new(password).digest() - stage2 = sha_new(stage1).digest() - s = sha_new() - s.update(message[:SCRAMBLE_LENGTH]) - s.update(stage2) - result = s.digest() - return _my_crypt(result, stage1) - -def _my_crypt(message1, message2): - length = len(message1) - result = b'' - for i in range_type(length): - x = (struct.unpack('B', message1[i:i+1])[0] ^ - struct.unpack('B', message2[i:i+1])[0]) - result += struct.pack('B', x) - return result - -# old_passwords support ported from libmysql/password.c -SCRAMBLE_LENGTH_323 = 8 - - -class RandStruct_323(object): - def __init__(self, seed1, seed2): - self.max_value = 0x3FFFFFFF - self.seed1 = seed1 % self.max_value - self.seed2 = seed2 % self.max_value - - def my_rnd(self): - self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value - self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value - return float(self.seed1) / float(self.max_value) - - -def _scramble_323(password, message): - hash_pass = _hash_password_323(password) - hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) - hash_pass_n = struct.unpack(">LL", hash_pass) - hash_message_n = struct.unpack(">LL", hash_message) - - rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0], - hash_pass_n[1] ^ hash_message_n[1]) - outbuf = io.BytesIO() - for _ in range_type(min(SCRAMBLE_LENGTH_323, len(message))): - outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) - extra = int2byte(int(rand_st.my_rnd() * 31)) - out = outbuf.getvalue() - outbuf = io.BytesIO() - for c in out: - outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) - return outbuf.getvalue() - - -def _hash_password_323(password): - nr = 1345345333 - add = 7 - nr2 = 0x12345671 - - # x in py3 is numbers, p27 is chars - for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: - nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF - nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF - add = (add + c) & 0xFFFFFFFF - - r1 = nr & ((1 << 31) - 1) # kill sign bits - r2 = nr2 & ((1 << 31) - 1) - return struct.pack(">LL", r1, r2) - - def pack_int24(n): return struct.pack(' Date: Sat, 23 Jun 2018 02:51:41 +0900 Subject: [PATCH 11/20] Fix some errors --- pymysql/_auth.py | 4 ++-- pymysql/connections.py | 47 ++++++++++++++++++++---------------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index ecba9d65..7425c7f8 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -172,7 +172,7 @@ def sha256_password_auth(conn, pkt): return _roundtrip(conn, data) -def caching_sha2_scramble(password, nonce): +def scramble_caching_sha2(password, nonce): # (bytes, bytes) -> bytes """Scramble algorithm used in cached_sha2_password fast path. @@ -202,7 +202,7 @@ def caching_sha2_password_auth(conn, pkt): if DEBUG: print("caching sha2: Trying fast path") conn.salt = pkt.read_all() - scrambled = caching_sha2_scramble(conn.password, conn.salt) + scrambled = scramble_caching_sha2(conn.password, conn.salt) pkt = _roundtrip(conn, scrambled) # else: fast auth is tried in initial handshake diff --git a/pymysql/connections.py b/pymysql/connections.py index 0099d95a..df408035 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -6,7 +6,6 @@ from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON import errno -import hashlib import io import os import socket @@ -92,9 +91,11 @@ def _makefile(sock, mode): MAX_PACKET_LEN = 2**24-1 + def pack_int24(n): return struct.pack('`_ in the specification. - + :raise Error: If the connection is already closed. """ if self._closed: @@ -380,7 +376,7 @@ def _force_close(self): if self._sock: try: self._sock.close() - except: + except: # noqa pass self._sock = None self._rfile = None @@ -419,7 +415,7 @@ def begin(self): def commit(self): """ Commit changes to stable storage. - + See `Connection.commit() `_ in the specification. """ @@ -429,7 +425,7 @@ def commit(self): def rollback(self): """ Roll back the current transaction. - + See `Connection.rollback() `_ in the specification. """ @@ -446,7 +442,7 @@ def show_warnings(self): def select_db(self, db): """ Set current db. - + :param db: The name of the db. """ self._execute_command(COMMAND.COM_INIT_DB, db) @@ -488,7 +484,7 @@ def _quote_bytes(self, s): def cursor(self, cursor=None): """ Create a new cursor to execute queries with. - + :param cursor: The type of cursor to create; one of :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`, or :py:class:`SSDictCursor`. None means use Cursor. @@ -536,7 +532,7 @@ def kill(self, thread_id): def ping(self, reconnect=True): """ Check if the server is alive. - + :param reconnect: If the connection is closed, reconnect. :raise Error: If the connection is closed and reconnect=False. """ @@ -618,7 +614,7 @@ def connect(self, sock=None): if sock is not None: try: sock.close() - except: + except: # noqa pass if isinstance(e, (OSError, IOError, socket.error)): @@ -745,7 +741,6 @@ def _execute_command(self, command, sql): :raise InterfaceError: If the connection is closed. :raise ValueError: If no username was specified. """ - if not self._sock: raise err.InterfaceError("(0, '')") @@ -809,14 +804,16 @@ def _request_authentication(self): plugin_name = None if self._auth_plugin_name in ('', 'mysql_native_password'): - authresp = _scramble(self.password, self.salt) + authresp = _auth.scramble_native_password(self.password, self.salt) elif self._auth_plugin_name == 'caching_sha2_password': plugin_name = b'caching_sha2_password' if self.password: - print("caching_sha2: trying fast path") - authresp = _auth._scramble_sha256_password(self.password, self.salt) + if DEBUG: + print("caching_sha2: trying fast path") + authresp = _auth.scramble_caching_sha2(self.password, self.salt) else: - print("caching_sha2: without password") + if DEBUG: + print("caching_sha2: empty password") elif self._auth_plugin_name == 'sha256_password': plugin_name = b'sha256_password' if self.ssl and self.server_capabilities & CLIENT.SSL: @@ -824,7 +821,7 @@ def _request_authentication(self): elif self.password: authresp = b'\1' # request public key else: - authresp = b'\0' # skip + authresp = b'\0' # empty password if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += lenenc_int(len(authresp)) + authresp @@ -875,7 +872,7 @@ def _request_authentication(self): elif self._auth_plugin_name == b"sha256_password": auth_packet = _auth.sha256_password_auth(self, auth_packet) else: - raise OperationalError("Received extra packet for auth method %r", self._auth_plugin_name) + raise err.OperationalError("Received extra packet for auth method %r", self._auth_plugin_name) if DEBUG: print("Succeed to auth") @@ -886,8 +883,8 @@ def _process_auth(self, plugin_name, auth_packet): return handler.authenticate(auth_packet) except AttributeError: if plugin_name != b'dialog': - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class)) + raise err.OperationalError(2059, "Authentication plugin '%s'" + " not loaded: - %r missing authenticate method" % (plugin_name, type(handler))) if plugin_name == b"caching_sha2_password": data = _auth.caching_sha2_password_auth(self, auth_packet) elif plugin_name == b"sha256_password": @@ -934,7 +931,7 @@ def _process_auth(self, plugin_name, auth_packet): pkt = self._read_packet() pkt.check_error() return pkt - + def _get_auth_plugin_handler(self, plugin_name): plugin_class = self._auth_plugin_map.get(plugin_name) if not plugin_class and isinstance(plugin_name, bytes): @@ -943,7 +940,7 @@ def _get_auth_plugin_handler(self, plugin_name): try: handler = plugin_class(self) except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ + raise err.OperationalError(2059, "Authentication plugin '%s'" " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) else: handler = None From 0555496aa0e1bc2eeb184c07b317e38ffa56cc6c Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 17:48:03 +0900 Subject: [PATCH 12/20] fix --- pymysql/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index df408035..f23f58d5 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -835,7 +835,7 @@ def _request_authentication(self): self.db = self.db.encode(self.encoding) data += self.db + b'\0' - if self.server_capabilities & CLIENT.PLUGIN_AUTH: + if plugin_name is not None and self.server_capabilities & CLIENT.PLUGIN_AUTH: data += plugin_name + b'\0' if self.server_capabilities & CLIENT.CONNECT_ATTRS: From 311487f2dd29e9a924b512ffd6f4cf4e409e0005 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 17:49:16 +0900 Subject: [PATCH 13/20] fix --- pymysql/_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index 7425c7f8..38563815 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -40,7 +40,7 @@ def scramble_native_password(password, message): def _my_crypt(message1, message2): length = len(message1) result = b'' - for i in range_type(length): + for i in range(length): x = ( struct.unpack('B', message1[i:i + 1])[0] ^ struct.unpack('B', message2[i:i + 1])[0] @@ -79,7 +79,7 @@ def scramble_old_password(password, message): hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1] ) outbuf = io.BytesIO() - for _ in range_type(min(SCRAMBLE_LENGTH_323, len(message))): + for _ in range(min(SCRAMBLE_LENGTH_323, len(message))): outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) extra = int2byte(int(rand_st.my_rnd() * 31)) out = outbuf.getvalue() From 59b748db257bde05d9f9fd0d8e7209a0863ac9a3 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 18:22:59 +0900 Subject: [PATCH 14/20] test --- .travis.yml | 6 +++- .travis/initializedb.sh | 9 ++++++ sha256_test.py | 26 ----------------- tests/__init__.py | 0 tests/test_auth.py | 63 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 27 deletions(-) delete mode 100644 sha256_test.py create mode 100644 tests/__init__.py create mode 100644 tests/test_auth.py diff --git a/.travis.yml b/.travis.yml index e5dc76c6..8d960249 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,13 +35,14 @@ matrix: python: "3.4" - env: - DB=mysql:8.0 + - TEST_AUTH=yes python: "3.7-dev" # different py version from 5.6 and 5.7 as cache seems to be based on py version # http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version # really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't install: - - pip install -U coveralls unittest2 coverage cryptography + - pip install -U coveralls unittest2 coverage cryptography pytest before_script: - ./.travis/initializedb.sh @@ -51,6 +52,9 @@ before_script: script: - coverage run ./runtests.py + - if [ "${TEST_AUTH}" = "yes" ]; + then pytest -v tests; + fi - if [ ! -z "${DB}" ]; then docker logs mysqld; fi diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh index 18c00eca..c149fe8e 100755 --- a/.travis/initializedb.sh +++ b/.travis/initializedb.sh @@ -37,6 +37,15 @@ if [ ! -z "${DB}" ]; then docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}" docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}" docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}" + + # Test user for auth test + mysql -e " + CREATE USER + user_sha256 IDENTIFIED WITH 'sha256_password' BY 'pass_sha256', + nopass_sha256 IDENTIFIED WITH 'sha256_password', + user_caching_sha2 IDENTIFIED WITH 'caching_sha2_password' BY 'pass_caching_sha2' + nopass_caching_sha2 IDENTIFIED WITH 'caching_sha2_password' + PASSWORD EXPIRE NEVER;" else WITH_PLUGIN='' fi diff --git a/sha256_test.py b/sha256_test.py deleted file mode 100644 index 3e6f0dab..00000000 --- a/sha256_test.py +++ /dev/null @@ -1,26 +0,0 @@ -import pymysql - -pymysql.connections.DEBUG = True -pymysql._auth.DEBUG = True - -host="127.0.0.1" -port=3306 - -ssl = {'ca': 'ca.pem', 'check_hostname': False} -#ssl = None - -print("### trying sha2 without password") -con = pymysql.connect(user="user_sha2_nopass", host=host, port=port, ssl=ssl) -print("OK\n\n\n") - -print("### trying sha2 with password") -con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=ssl) -print("OK\n\n\n") - -print("### trying caching sha2 without password") -con = pymysql.connect(user="user_csha2_nopass", host=host, port=port, ssl=ssl) -print("OK\n\n\n") - -print("### trying caching sha2 with password") -con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=ssl) -print("OK\n\n\n") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 00000000..7d857344 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,63 @@ +"""Test for auth methods supported by MySQL 8""" + +import os +import pymysql + +# pymysql.connections.DEBUG = True +# pymysql._auth.DEBUG = True + +host = "127.0.0.1" +port = 3306 + +ca = os.path.expanduser("~/ca.pem") +ssl = {'ca': ca, 'check_hostname': False} + + +def test_sha256_no_password(): + con = pymysql.connect(user="nopass_sha256", host=host, port=port, ssl=None) + con.close() + + +def test_sha256_no_passowrd_ssl(): + con = pymysql.connect(user="nopass_sha256", host=host, port=port, ssl=ssl) + con.close() + + +def test_sha256_password(): + con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=None) + con.close() + + +def test_sha256_password_ssl(): + con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=ssl) + con.close() + + +def test_caching_sha2_no_password(): + con = pymysql.connect(user="nopass_caching_sha2", host=host, port=port, ssl=None) + con.close() + + +def test_caching_sha2_no_password(): + con = pymysql.connect(user="nopass_caching_sha2", host=host, port=port, ssl=ssl) + con.close() + + +def test_caching_sha2_password(): + con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con.close() + + # Fast path of caching sha2 + con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con.query("FLUSH PRIVILEGES") + con.close() + + +def test_caching_sha2_password_ssl(): + con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=ssl) + con.close() + + # Fast path of caching sha2 + con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con.query("FLUSH PRIVILEGES") + con.close() From 69d8512781aa210584fe01e13ce8b2bcfe4822d5 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 19:02:09 +0900 Subject: [PATCH 15/20] fix --- pymysql/connections.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index f23f58d5..44bb5be5 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -865,11 +865,12 @@ def _request_authentication(self): self.write_packet(data) auth_packet = self._read_packet() elif auth_packet.is_extra_auth_data(): - print("received extra data") + if DEBUG: + print("received extra data") # https://dev.mysql.com/doc/internals/en/successful-authentication.html - if self._auth_plugin_name == b"caching_sha2_password": + if self._auth_plugin_name == "caching_sha2_password": auth_packet = _auth.caching_sha2_password_auth(self, auth_packet) - elif self._auth_plugin_name == b"sha256_password": + elif self._auth_plugin_name == "sha256_password": auth_packet = _auth.sha256_password_auth(self, auth_packet) else: raise err.OperationalError("Received extra packet for auth method %r", self._auth_plugin_name) @@ -1018,9 +1019,9 @@ def _get_server_information(self): server_end = data.find(b'\0', i) if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all - self._auth_plugin_name = data[i:].decode('latin1') + self._auth_plugin_name = data[i:].decode('utf-8') else: - self._auth_plugin_name = data[i:server_end].decode('latin1') + self._auth_plugin_name = data[i:server_end].decode('utf-8') def get_server_info(self): return self.server_version From 47839da1ef2a02971e7189ce1c24d49f62a97de2 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 19:14:25 +0900 Subject: [PATCH 16/20] Fix initializedb --- .travis/initializedb.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh index c149fe8e..221f3024 100755 --- a/.travis/initializedb.sh +++ b/.travis/initializedb.sh @@ -39,13 +39,13 @@ if [ ! -z "${DB}" ]; then docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}" # Test user for auth test - mysql -e " + mysql -e ' CREATE USER - user_sha256 IDENTIFIED WITH 'sha256_password' BY 'pass_sha256', - nopass_sha256 IDENTIFIED WITH 'sha256_password', - user_caching_sha2 IDENTIFIED WITH 'caching_sha2_password' BY 'pass_caching_sha2' - nopass_caching_sha2 IDENTIFIED WITH 'caching_sha2_password' - PASSWORD EXPIRE NEVER;" + user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256", + nopass_sha256 IDENTIFIED WITH "sha256_password", + user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2", + nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password" + PASSWORD EXPIRE NEVER;' else WITH_PLUGIN='' fi From 9739ce2440b6cbcfb3d0d7aff2e24289b3d7abd8 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 19:56:32 +0900 Subject: [PATCH 17/20] Disable connect attrs by default --- pymysql/constants/CLIENT.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py index b42f1523..b09412c8 100644 --- a/pymysql/constants/CLIENT.py +++ b/pymysql/constants/CLIENT.py @@ -23,7 +23,8 @@ CAPABILITIES = ( LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS | SECURE_CONNECTION | MULTI_RESULTS - | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) + | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA) + #| PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) # Not done yet HANDLE_EXPIRED_PASSWORDS = 1 << 22 From c42dd4e91c1460798b86b97c4e69b8204fb75d5e Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 21:06:14 +0900 Subject: [PATCH 18/20] fix MariaDB fails --- pymysql/connections.py | 4 ++-- pymysql/constants/CLIENT.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index 44bb5be5..96df4903 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -835,8 +835,8 @@ def _request_authentication(self): self.db = self.db.encode(self.encoding) data += self.db + b'\0' - if plugin_name is not None and self.server_capabilities & CLIENT.PLUGIN_AUTH: - data += plugin_name + b'\0' + if self.server_capabilities & CLIENT.PLUGIN_AUTH: + data += (plugin_name or b'') + b'\0' if self.server_capabilities & CLIENT.CONNECT_ATTRS: connect_attrs = b'' diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py index b09412c8..b42f1523 100644 --- a/pymysql/constants/CLIENT.py +++ b/pymysql/constants/CLIENT.py @@ -23,8 +23,7 @@ CAPABILITIES = ( LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS | SECURE_CONNECTION | MULTI_RESULTS - | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA) - #| PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) + | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) # Not done yet HANDLE_EXPIRED_PASSWORDS = 1 << 22 From 6d993320daba46eb8c0b389c3a8d8d7221a1ae23 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 21:31:52 +0900 Subject: [PATCH 19/20] bugfix --- pymysql/_auth.py | 2 +- pymysql/connections.py | 4 ++-- runtests.py | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pymysql/_auth.py b/pymysql/_auth.py index 38563815..ddf6e4e5 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -121,7 +121,7 @@ def _xor_password(password, salt): salt_len = len(salt) for i in range(len(password_bytes)): password_bytes[i] ^= salt[i % salt_len] - return password_bytes + return bytes(password_bytes) def sha2_rsa_encrypt(password, salt, public_key): diff --git a/pymysql/connections.py b/pymysql/connections.py index 96df4903..14ae76d9 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -887,9 +887,9 @@ def _process_auth(self, plugin_name, auth_packet): raise err.OperationalError(2059, "Authentication plugin '%s'" " not loaded: - %r missing authenticate method" % (plugin_name, type(handler))) if plugin_name == b"caching_sha2_password": - data = _auth.caching_sha2_password_auth(self, auth_packet) + return _auth.caching_sha2_password_auth(self, auth_packet) elif plugin_name == b"sha256_password": - data = _auth.sha256_password_auth(self, auth_packet) + return _auth.sha256_password_auth(self, auth_packet) elif plugin_name == b"mysql_native_password": data = _auth.scramble_native_password(self.password, auth_packet.read_all()) elif plugin_name == b"mysql_old_password": diff --git a/runtests.py b/runtests.py index 00e492b0..ea3d9e8d 100755 --- a/runtests.py +++ b/runtests.py @@ -3,6 +3,10 @@ from pymysql._compat import PYPY, JYTHON, IRONPYTHON +#import pymysql +#pymysql.connections.DEBUG = True +#pymysql._auth.DEBUG = True + if not (PYPY or JYTHON or IRONPYTHON): import atexit import gc From 40adcf5cff4ecdebde5b23d8f4c45739576c18c9 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 26 Jun 2018 21:33:09 +0900 Subject: [PATCH 20/20] Add GRANT for "FLUSH PRIVILEGES" --- .travis/initializedb.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh index 221f3024..d9897e49 100755 --- a/.travis/initializedb.sh +++ b/.travis/initializedb.sh @@ -46,6 +46,7 @@ if [ ! -z "${DB}" ]; then user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2", nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password" PASSWORD EXPIRE NEVER;' + mysql -e 'GRANT RELOAD ON *.* TO user_caching_sha2;' else WITH_PLUGIN='' fi